Unverified Commit d9fa13ce authored by Kola's avatar Kola Committed by GitHub
Browse files

Add docstrings and types for MambaCache (#30023)

* Add docstrings and types for MambaCache

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* make fixup

* import copy in generation_whisper

* ruff

* Revert "make fixup"

This reverts commit c4fedd6f60e3b0f11974a11433bc130478829a5c.
parent b17b54d3
......@@ -61,7 +61,23 @@ from ..deprecated._archive_maps import MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST # no
class MambaCache:
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
"""
Arguments:
config: MambaConfig
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""
def __init__(
self, config: MambaConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
intermediate_size = config.intermediate_size
......@@ -86,13 +102,13 @@ class MambaMixer(nn.Module):
and is why Mamba is called **selective** state spaces)
"""
def __init__(self, config, layer_idx):
def __init__(self, config: MambaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = config.time_step_rank
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment