"docs/vscode:/vscode.git/clone" did not exist on "d93d2d74fd807a091add17c2065ee8869339f76a"
Unverified Commit 17c1bdf3 authored by PatchyTIS's avatar PatchyTIS Committed by GitHub
Browse files

[Bugfix] dtype mismatch in ngram gpu propose (#37246)


Signed-off-by: default avatarPatchouliTaisa <patchychen@tencent.com>
Co-authored-by: default avatarPatchouliTaisa <patchychen@tencent.com>
parent 3e3d320c
...@@ -364,7 +364,9 @@ class NgramProposerGPU: ...@@ -364,7 +364,9 @@ class NgramProposerGPU:
) )
token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter) token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)
num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count num_tokens_tmp = (num_tokens_no_spec + valid_sampled_tokens_count).to(
torch.int32
)
# Compute validity masks. # Compute validity masks.
sampled_flags = valid_sampled_tokens_count > 0 sampled_flags = valid_sampled_tokens_count > 0
...@@ -437,7 +439,7 @@ class NgramProposerGPU: ...@@ -437,7 +439,7 @@ class NgramProposerGPU:
) )
# Count valid tokens per request. # Count valid tokens per request.
valid_sampled_tokens_count = valid_mask.sum(dim=1) valid_sampled_tokens_count = valid_mask.sum(dim=1).to(torch.int32)
# Rightmost valid index per row. # Rightmost valid index per row.
last_valid_indices = valid_sampled_tokens_count - 1 last_valid_indices = valid_sampled_tokens_count - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment