Unverified Commit 75531a6c authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)


Signed-off-by: default avatarDaniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
Co-authored-by: default avatarDaniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: default avatarBurkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: default avatarChen Zhang <zhangch99@outlook.com>
parent 22341b99
......@@ -173,6 +173,7 @@ CYAN = '\033[1;36m'
RESET = '\033[0;0m'
STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
......
......@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
dtypes: tuple[torch.dtype]
page_size_padded: Optional[int] = None
mamba_type: str = "mamba2"
@property
def page_size_bytes(self) -> int:
num_elements = sum(prod(shape) for shape in self.shapes)
page_size = num_elements * get_dtype_size(self.dtype)
page_size = sum(
prod(shape) * get_dtype_size(dtype)
for (shape, dtype) in zip(self.shapes, self.dtypes))
if self.page_size_padded is not None:
assert self.page_size_padded >= page_size
return self.page_size_padded
......
......@@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
dtype = kv_cache_spec.dtype
num_element_per_page = (kv_cache_spec.page_size_bytes //
get_dtype_size(dtype))
state_tensors = []
storage_offset = 0
for shape in kv_cache_spec.shapes:
storage_offset_bytes = 0
for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset,
storage_offset=storage_offset_bytes // dtype_size,
)
state_tensors.append(tensor)
storage_offset += stride[0]
storage_offset_bytes += stride[0] * dtype_size
kv_caches[layer_name] = state_tensors
else:
......@@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment