"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "537cc635c77ac63f643c5289137debdd8f9591ac"
Unverified Commit 4ef95b0f authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Bugfix] use float32 precision in samplers/test_logprobs.py for comparing with HF (#6409)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent eaec4b91
...@@ -11,7 +11,8 @@ MODELS = ["facebook/opt-125m"] ...@@ -11,7 +11,8 @@ MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype",
["float"]) # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size @pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False]) @pytest.mark.parametrize("detokenize", [True, False])
......
...@@ -687,6 +687,12 @@ if triton.__version__ >= "2.1.0": ...@@ -687,6 +687,12 @@ if triton.__version__ >= "2.1.0":
cap = current_platform.get_device_capability() cap = current_platform.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64 BLOCK = 128 if cap[0] >= 8 else 64
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if q.dtype is torch.float32:
BLOCK = BLOCK // 2
# shape constraints # shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment