Unverified Commit 75531a6c authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)


Signed-off-by: default avatarDaniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
Co-authored-by: default avatarDaniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: default avatarBurkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: default avatarChen Zhang <zhangch99@outlook.com>
parent 22341b99
...@@ -173,6 +173,7 @@ CYAN = '\033[1;36m' ...@@ -173,6 +173,7 @@ CYAN = '\033[1;36m'
RESET = '\033[0;0m' RESET = '\033[0;0m'
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
......
...@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec): ...@@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
@dataclass(frozen=True) @dataclass(frozen=True)
class MambaSpec(KVCacheSpec): class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...] shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype dtypes: tuple[torch.dtype]
page_size_padded: Optional[int] = None page_size_padded: Optional[int] = None
mamba_type: str = "mamba2" mamba_type: str = "mamba2"
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
num_elements = sum(prod(shape) for shape in self.shapes) page_size = sum(
page_size = num_elements * get_dtype_size(self.dtype) prod(shape) * get_dtype_size(dtype)
for (shape, dtype) in zip(self.shapes, self.dtypes))
if self.page_size_padded is not None: if self.page_size_padded is not None:
assert self.page_size_padded >= page_size assert self.page_size_padded >= page_size
return self.page_size_padded return self.page_size_padded
......
...@@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name]
dtype = kv_cache_spec.dtype
num_element_per_page = (kv_cache_spec.page_size_bytes //
get_dtype_size(dtype))
state_tensors = [] state_tensors = []
storage_offset = 0 storage_offset_bytes = 0
for shape in kv_cache_spec.shapes: for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape) target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride() stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:]) target_stride = (num_element_per_page, *stride[1:])
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided( tensor = torch.as_strided(
raw_tensor.view(dtype), raw_tensor.view(dtype),
size=target_shape, size=target_shape,
stride=target_stride, stride=target_stride,
storage_offset=storage_offset, storage_offset=storage_offset_bytes // dtype_size,
) )
state_tensors.append(tensor) state_tensors.append(tensor)
storage_offset += stride[0] storage_offset_bytes += stride[0] * dtype_size
kv_caches[layer_name] = state_tensors kv_caches[layer_name] = state_tensors
else: else:
...@@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for layer_name, mamba_module in mamba_layers.items(): for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec( kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(), shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype, dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len, block_size=max_model_len,
page_size_padded=page_size_padded, page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type) mamba_type=mamba_module.mamba_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment