Commit 4e06836d authored by zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[fix]修复test_spec_decode_worker中的错误

See merge request dcutoolkit/deeplearing/vllm!34
parents aad58f06 4d29e0a8
...@@ -60,6 +60,16 @@ class MockAttentionBackend(AttentionBackend): ...@@ -60,6 +60,16 @@ class MockAttentionBackend(AttentionBackend):
) -> None: ) -> None:
pass pass
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
def test_model_runner_input(): def test_model_runner_input():
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
......
...@@ -221,6 +221,16 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -221,6 +221,16 @@ class FlashAttentionBackend(AttentionBackend):
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
@dataclass @dataclass
class FlashAttentionMetadata(AttentionMetadata): class FlashAttentionMetadata(AttentionMetadata):
......
...@@ -550,14 +550,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -550,14 +550,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states, execute_model_req.seq_group_metadata_list) hidden_states, execute_model_req.seq_group_metadata_list)
# Store logits from target model execution. # Store logits from target model execution.
logits = sampler_output.logits if self.tree_style_spec_decoding:
if logits is not None: logits = sampler_output.logits
if self.previous_logits is None: if logits is not None:
self.previous_logits = Logits( if self.previous_logits is None:
logits, execute_model_req.seq_group_metadata_list) self.previous_logits = Logits(
else: logits, execute_model_req.seq_group_metadata_list)
self.previous_logits.update( else:
logits, execute_model_req.seq_group_metadata_list) self.previous_logits.update(
logits, execute_model_req.seq_group_metadata_list)
if not skip_proposer: if not skip_proposer:
# We prepare the prefill hidden states here so that there no # We prepare the prefill hidden states here so that there no
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment