Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
......@@ -54,6 +54,16 @@ class MambaStateDtypeCalculator:
return (conv_state_dtype, temporal_state_dtype)
@classmethod
def short_conv_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
return (conv_state_dtype, )
class MambaStateShapeCalculator:
......@@ -122,6 +132,20 @@ class MambaStateShapeCalculator:
tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape
@classmethod
def short_conv_state_shape(
cls,
tp_world_size: int,
intermediate_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
if not use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
return (conv_state_shape, )
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment