"vllm/multimodal/processing/dummy_inputs.py" did not exist on "e93ff6c8b92b7e7d067f535b52ac7c6304e6b316"
Commit a363d2c3 authored by zhuwenwen
Browse files

[fix]修复不开mtp精度异常问题

parent 7ff48a6c
......@@ -1188,7 +1188,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")),
"VLLM_USE_ZERO_MTP":
lambda: (os.getenv('VLLM_USE_ZERO_MTP', '0').lower() in
lambda: (os.getenv('VLLM_USE_ZERO_MTP', '1').lower() in
("true", "1")),
# vllm will use 1-24... (not only 1 2 4 8 16 24)
......
......@@ -3315,11 +3315,11 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
......@@ -3338,7 +3338,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
self.spec_sampler_event.synchronize()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment