Unverified Commit 76b3b20f authored by Kola, committed by GitHub

Update Mamba types and pass through use_cache attr to MambaModel (#29605)



* Update docstring for RMSNorm

* Update cache_params object to correct MambaCache type

* Update docstrings and type info

* Pass through use_cache

* ruff

* Reformat with 119 char limit per line (thanks Arthur)

* Pass through use_cache specifically to the backbone rather than all keyword arguments

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tab

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 776c9d3a
@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -59,6 +59,24 @@ _CONFIG_FOR_DOC = "MambaConfig"
 MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = []  # See all Mamba models at https://huggingface.co/models?filter=mamba


+class MambaCache:
+    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
+        self.seqlen_offset = 0
+        self.dtype = dtype
+        intermediate_size = config.intermediate_size
+        ssm_state_size = config.state_size
+        conv_kernel_size = config.conv_kernel
+
+        self.conv_states = {
+            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+            for i in range(config.num_hidden_layers)
+        }
+        self.ssm_states = {
+            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
+            for i in range(config.num_hidden_layers)
+        }
+
+
 class MambaMixer(nn.Module):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
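
For context on the block added above: a minimal sketch of constructing the relocated `MambaCache` and inspecting what it holds. It assumes a `transformers` install that contains this commit; the default `MambaConfig()` is used purely for illustration.

```python
import torch

from transformers import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaCache

config = MambaConfig()  # default config, for illustration only
cache = MambaCache(config, batch_size=2, dtype=torch.float32, device="cpu")

# One rolling convolution buffer and one SSM state per layer, keyed by layer index.
assert len(cache.conv_states) == config.num_hidden_layers
assert cache.conv_states[0].shape == (2, config.intermediate_size, config.conv_kernel)
assert cache.ssm_states[0].shape == (2, config.intermediate_size, config.state_size)
assert cache.seqlen_offset == 0  # the model advances this as tokens are processed
```
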
@@ -112,7 +130,7 @@ class MambaMixer(nn.Module):
                 " https://github.com/Dao-AILab/causal-conv1d"
             )

-    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None):
+    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -202,7 +220,7 @@ class MambaMixer(nn.Module):
         return contextualized_states

     # fmt: off
-    def slow_forward(self, input_states, cache_params=None):
+    def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
         # 1. Gated MLP's linear projection
@@ -268,34 +286,16 @@ class MambaMixer(nn.Module):
         return contextualized_states
     # fmt: on

-    def forward(self, hidden_states, cache_params=None):
+    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
         if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params)
         return self.slow_forward(hidden_states, cache_params)


-class MambaCache:
-    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
-        self.seqlen_offset = 0
-        self.dtype = dtype
-        intermediate_size = config.intermediate_size
-        ssm_state_size = config.state_size
-        conv_kernel_size = config.conv_kernel
-
-        self.conv_states = {
-            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
-            for i in range(config.num_hidden_layers)
-        }
-        self.ssm_states = {
-            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
-            for i in range(config.num_hidden_layers)
-        }
-
-
 class MambaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
-        LlamaRMSNorm is equivalent to T5LayerNorm
+        MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
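
To make the updated docstring concrete: MambaRMSNorm, T5LayerNorm, and LlamaRMSNorm all implement the same RMS normalization, shown below as a self-contained sketch (not the exact code from this file): normalize by the root mean square over the hidden dimension, then scale by a learned weight, with no mean subtraction and no bias.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Illustrative RMSNorm, mirroring the pattern shared by the classes above."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        # Mean of squares over the hidden dimension (the "RMS" part); no mean-centering.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```
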
@@ -318,7 +318,7 @@ class MambaBlock(nn.Module):
         self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.mixer = MambaMixer(config, layer_idx=layer_idx)

-    def forward(self, hidden_states, cache_params=None):
+    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
         if self.residual_in_fp32:
@@ -396,11 +396,11 @@ class MambaOutput(ModelOutput):
     Args:
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
-        cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+        cache_params (`MambaCache`):
             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
             avoid providing the old `input_ids`.

-            Includes both the State space model states weights after the selective scan, and the Convolutional states
+            Includes both the State space model state matrices after the selective scan, and the Convolutional states
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -408,8 +408,8 @@ class MambaOutput(ModelOutput):
             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
     """

-    last_hidden_state: torch.FloatTensor = None
-    cache_params: Optional[List[torch.FloatTensor]] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    cache_params: Optional[MambaCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
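
Since `cache_params` is now typed as a `MambaCache` rather than a list of tensors, the intended round trip with `MambaModel` looks roughly like the sketch below. `state-spaces/mamba-130m-hf` is only an example checkpoint id, and the call pattern is a hedged illustration rather than the canonical recipe.

```python
import torch

from transformers import AutoTokenizer, MambaModel

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaModel.from_pretrained("state-spaces/mamba-130m-hf")

input_ids = tokenizer("Mamba keeps a recurrent state", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)

# out.cache_params is a MambaCache holding the conv/SSM states after the prompt.
new_token = tokenizer(" here", return_tensors="pt").input_ids[:, -1:]  # keep exactly one token
with torch.no_grad():
    out_next = model(input_ids=new_token, cache_params=out.cache_params, use_cache=True)
```
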
@@ -423,9 +423,11 @@ class MambaCausalLMOutput(ModelOutput):
             Language modeling loss (for next-token prediction).
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
+        cache_params (`MambaCache`):
             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
             avoid providing the old `input_ids`.
+
+            Includes both the State space model state matrices after the selective scan, and the Convolutional states
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -434,8 +436,8 @@ class MambaCausalLMOutput(ModelOutput):
     """

     loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    cache_params: Optional[List[torch.FloatTensor]] = None
+    logits: Optional[torch.FloatTensor] = None
+    cache_params: Optional[MambaCache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -516,7 +518,7 @@ class MambaModel(MambaPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
-        cache_params: Optional[List[torch.FloatTensor]] = None,
+        cache_params: Optional[MambaCache] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -609,7 +611,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
         return model_kwargs

     def prepare_inputs_for_generation(
-        self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs
+        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
     ):
         # only last token for inputs_ids if the state is passed along.
         if cache_params is not None:
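
The surrounding context implements the usual recurrent-cache convention: once `cache_params` is present, only the newest token is fed back in. A minimal sketch of that idea (a standalone illustration, not the exact method body):

```python
from typing import Optional

import torch


def prepare_inputs_sketch(input_ids: torch.LongTensor, cache_params: Optional["MambaCache"] = None, **kwargs):
    """Standalone illustration of the last-token slicing done when a cache exists."""
    if cache_params is not None:
        # The MambaCache already summarizes everything seen so far,
        # so only the newest token is needed on this step.
        input_ids = input_ids[:, -1].unsqueeze(-1)
    return {"input_ids": input_ids, "cache_params": cache_params}
```
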
@@ -633,10 +635,11 @@ class MambaForCausalLM(MambaPreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_params: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[MambaCache] = None,
         labels: Optional[torch.LongTensor] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
         **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
@@ -653,6 +656,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
             inputs_embeds=inputs_embeds,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            use_cache=use_cache,
         )
         hidden_states = mamba_outputs[0]
...
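
Taken together with the `MambaModel` changes, `use_cache` now reaches the backbone when `MambaForCausalLM` is called directly. A hedged end-to-end sketch, again using `state-spaces/mamba-130m-hf` purely as an example checkpoint:

```python
import torch

from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

input_ids = tokenizer("The state space model", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)  # use_cache is now forwarded to the backbone

# Greedy one-step continuation that reuses the returned MambaCache.
next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
with torch.no_grad():
    out_next = model(input_ids=next_id, cache_params=out.cache_params, use_cache=True)
print(tokenizer.decode(out_next.logits[:, -1].argmax(dim=-1)))
```
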