Unverified commit eca18691, authored by Dhia Eddine Rhaiem, committed by GitHub
Browse files

[MODEL] FalconH1 (#18406)


Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
parent 61acfc45
...@@ -392,6 +392,11 @@ Specified using `--task generate`. ...@@ -392,6 +392,11 @@ Specified using `--task generate`.
* `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. * `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `FalconH1ForCausalLM`
* Falcon-H1
* `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc.
* ✅︎
* ✅︎
- * `GemmaForCausalLM` - * `GemmaForCausalLM`
* Gemma * Gemma
* `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc.
......
...@@ -147,6 +147,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -147,6 +147,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
is_available_online=False,
min_transformers_version="4.52.2"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
......
...@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
@CustomOp.register("mixer2_gated_rms_norm") @CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp): class Mixer2RMSNormGated(CustomOp):
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
self.n_groups = full_hidden_size // self.group_size self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps self.variance_epsilon = eps
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) self.use_rms_norm = use_rms_norm
set_weight_attrs(self.weight, if self.use_rms_norm:
{"weight_loader": sharded_weight_loader(0)}) # Register norm weight only if we're actually applying RMSNorm
assert self.full_hidden_size % self.tp_size== 0,\ self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
"Tensor parallel world size must divide hidden size." set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
def forward_native( def forward_native(
self, self,
...@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
# the input and then redundantly compute the RMSNorm. # the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32)) x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
if self.n_groups == 1: if self.n_groups == 1:
if self.tp_size > 1: if self.tp_size > 1:
...@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
global_sums = tensor_model_parallel_all_reduce(local_sums) global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance # Calculate the variance
count = self.tp_size * x.shape[-1] count = self.tp_size * x.shape[-1]
variance = (global_sums / count) variance = global_sums / count
else: else:
variance = x.pow(2).mean(-1, keepdim=True) variance = x.pow(2).mean(-1, keepdim=True)
...@@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp):
gate: torch.Tensor, gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
if self.tp_size > 1 or self.n_groups != 1: if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate) return self.forward_native(x, gate)
...@@ -124,7 +139,7 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -124,7 +139,7 @@ class Mixer2RMSNormGated(CustomOp):
def extra_groups_for_head_shards(ngroups: int, tp_size: int): def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for """Compute the increase in group numbers to account for
replication in order to accompany the head shards.""" replication in order to accompany the head shards."""
# in the case ngroups % tp_size == 0, this will be zero # in the case ngroups % tp_size == 0, this will be zero
...@@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader( ...@@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well. # seem to handle slices well.
# https://github.com/python/mypy/issues/2410 # https://github.com/python/mypy/issues/2410
param.data[ param.data[
boundary:(boundary + take), # type: ignore[misc] boundary:(boundary + take),
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] ... # type: ignore[misc]
loaded_start_idx + take)] # type: ignore[misc] ] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries # move indexing boundaries
boundary += shard_size boundary += shard_size
loaded_boundary += (full_dim - extra) loaded_boundary += full_dim - extra
return loader return loader
...@@ -206,19 +223,22 @@ class MambaMixer2(CustomOp): ...@@ -206,19 +223,22 @@ class MambaMixer2(CustomOp):
**selective** state spaces) **selective** state spaces)
""" """
def __init__(self, def __init__(
hidden_size: int, self,
ssm_state_size: int, hidden_size: int,
conv_kernel_size: int, ssm_state_size: int,
intermediate_size: int, conv_kernel_size: int,
use_conv_bias: bool, intermediate_size: int,
use_bias: bool, use_conv_bias: bool,
n_groups: int = 1, use_bias: bool,
num_heads: int = 128, n_groups: int = 1,
head_dim: int = 64, num_heads: int = 128,
rms_norm_eps: float = 1e-5, head_dim: int = 64,
activation="silu", rms_norm_eps: float = 1e-5,
quant_config: Optional[QuantizationConfig] = None): activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__() super().__init__()
# For TP, the sharding plan is as follows: # For TP, the sharding plan is as follows:
...@@ -238,17 +258,16 @@ class MambaMixer2(CustomOp): ...@@ -238,17 +258,16 @@ class MambaMixer2(CustomOp):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
assert num_heads % self.tp_size == 0, \ assert (num_heads % self.tp_size == 0
"Tensor parallel world size must divide num heads." ), "Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
( "If tensor parallel world size does not divide num_heads, "
"If tensor parallel world size does not divide num_heads, " "then num_groups must equal 1.")
"then num_groups must equal 1."
)
assert self.tp_size == 1 or quant_config is None, \ assert (
"Tensor parallel currently not supported for quantized models." self.tp_size == 1 or quant_config is None
), "Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.activation = activation self.activation = activation
...@@ -265,8 +284,7 @@ class MambaMixer2(CustomOp): ...@@ -265,8 +284,7 @@ class MambaMixer2(CustomOp):
self.n_groups = n_groups + extra_groups_for_head_shards( self.n_groups = n_groups + extra_groups_for_head_shards(
n_groups, self.tp_size) n_groups, self.tp_size)
self.conv_dim = (intermediate_size + self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
2 * self.n_groups * ssm_state_size)
self.conv1d = ColumnParallelLinear( self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size, input_size=conv_kernel_size,
output_size=self.conv_dim, output_size=self.conv_dim,
...@@ -279,11 +297,12 @@ class MambaMixer2(CustomOp): ...@@ -279,11 +297,12 @@ class MambaMixer2(CustomOp):
# doesn't allow to override it # doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = ColumnParallelLinear(input_size=hidden_size, self.in_proj = ColumnParallelLinear(
output_size=intermediate_size + input_size=hidden_size,
self.conv_dim + self.num_heads, output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias, bias=use_bias,
quant_config=quant_config) quant_config=quant_config,
)
# - because in_proj is a concatenation of 3 weights, we # - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding # need to interleave them before sharding
...@@ -305,7 +324,8 @@ class MambaMixer2(CustomOp): ...@@ -305,7 +324,8 @@ class MambaMixer2(CustomOp):
# - ditto for the other two weights below # - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader") delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.bias, { self.conv1d.bias,
{
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader( mamba_v2_sharded_weight_loader(
[ [
...@@ -316,18 +336,25 @@ class MambaMixer2(CustomOp): ...@@ -316,18 +336,25 @@ class MambaMixer2(CustomOp):
self.tp_size, self.tp_size,
tp_rank, tp_rank,
) )
}) },
)
delattr(self.conv1d.weight, "weight_loader") delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.weight, { self.conv1d.weight,
{
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader([ mamba_v2_sharded_weight_loader(
intermediate_settings, [
group_shard_settings, intermediate_settings,
group_shard_settings, group_shard_settings,
], self.tp_size, tp_rank) group_shard_settings,
}) ],
self.tp_size,
tp_rank,
)
},
)
if quant_config is None: if quant_config is None:
# - quant layers do not have a weight loader # - quant layers do not have a weight loader
...@@ -345,8 +372,10 @@ class MambaMixer2(CustomOp): ...@@ -345,8 +372,10 @@ class MambaMixer2(CustomOp):
head_setings, # for dt head_setings, # for dt
], ],
self.tp_size, self.tp_size,
tp_rank) tp_rank,
}) )
},
)
# - these are TPed by heads to reduce the size of the # - these are TPed by heads to reduce the size of the
# temporal shape # temporal shape
...@@ -357,6 +386,7 @@ class MambaMixer2(CustomOp): ...@@ -357,6 +386,7 @@ class MambaMixer2(CustomOp):
)) ))
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader( a_weight_loader = composed_weight_loader(
...@@ -365,18 +395,25 @@ class MambaMixer2(CustomOp): ...@@ -365,18 +395,25 @@ class MambaMixer2(CustomOp):
set_weight_attrs(self.dt_bias, set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)}) {"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(intermediate_size, self.out_proj = RowParallelLinear(
hidden_size, intermediate_size,
bias=use_bias, hidden_size,
input_is_parallel=True, bias=use_bias,
quant_config=quant_config) input_is_parallel=True,
quant_config=quant_config,
)
self.norm = Mixer2RMSNormGated(intermediate_size, self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups, n_groups,
self.use_rms_norm,
eps=rms_norm_eps) eps=rms_norm_eps)
def forward_native(self, hidden_states: torch.Tensor, def forward_native(
conv_state: torch.Tensor, ssm_state: torch.Tensor): self,
hidden_states: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
pass pass
def forward_cuda( def forward_cuda(
...@@ -384,6 +421,7 @@ class MambaMixer2(CustomOp): ...@@ -384,6 +421,7 @@ class MambaMixer2(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata, mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
): ):
# mamba2_metadata contains metadata necessary for the mamba2 triton # mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
...@@ -401,6 +439,10 @@ class MambaMixer2(CustomOp): ...@@ -401,6 +439,10 @@ class MambaMixer2(CustomOp):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states) projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
gate, hidden_states_B_C, dt = torch.split( gate, hidden_states_B_C, dt = torch.split(
projected_states, projected_states,
[ [
...@@ -561,6 +603,9 @@ class MambaMixer2(CustomOp): ...@@ -561,6 +603,9 @@ class MambaMixer2(CustomOp):
hidden_states = torch.vstack(ssd_output_list) hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP # 4. gated MLP
# Mixer2RMSNormGated applies SiLU to the gate internally:
# the input is multiplied by silu(gate) before any normalization,
# unlike standard norm usage
hidden_states = self.norm(hidden_states, gate) hidden_states = self.norm(hidden_states, gate)
# 5. Final linear projection # 5. Final linear projection
......
This diff is collapsed.
...@@ -79,6 +79,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -79,6 +79,7 @@ _TEXT_GENERATION_MODELS = {
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment