Unverified Commit 4b53740d authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[MRV2] Fix for DS v3.2 (#38030)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 4e824d1c
...@@ -115,9 +115,12 @@ def _reshape_kv_cache( ...@@ -115,9 +115,12 @@ def _reshape_kv_cache(
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups: for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
kv_cache_spec = kv_cache_spec.kv_cache_specs[layer_name]
assert isinstance(kv_cache_spec, AttentionSpec)
raw_tensor = kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment