"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "5ff5d94e77851f8ca11592bbb9aa414e65f4c353"
Commit a363d2c3 authored by zhuwenwen's avatar zhuwenwen
Browse files

[fix]修复不开mtp精度异常问题

parent 7ff48a6c
...@@ -1188,7 +1188,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1188,7 +1188,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
"VLLM_USE_ZERO_MTP": "VLLM_USE_ZERO_MTP":
lambda: (os.getenv('VLLM_USE_ZERO_MTP', '0').lower() in lambda: (os.getenv('VLLM_USE_ZERO_MTP', '1').lower() in
("true", "1")), ("true", "1")),
# vllm will use 1-24... (not only 1 2 4 8 16 24) # vllm will use 1-24... (not only 1 2 4 8 16 24)
......
...@@ -3315,11 +3315,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3315,11 +3315,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
if not self.speculative_config: if not self.speculative_config:
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
else: else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record() self.spec_sampler_event.record()
mask = (sampled_token_ids == -1) mask = (sampled_token_ids == -1)
mask_int = mask.int() mask_int = mask.int()
...@@ -3338,7 +3338,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3338,7 +3338,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
) )
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids_cpu.tolist() valid_sampled_token_ids = sampled_token_ids.tolist()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
self.spec_sampler_event.synchronize() self.spec_sampler_event.synchronize()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment