Unverified Commit eca18691 authored by Dhia Eddine Rhaiem's avatar Dhia Eddine Rhaiem Committed by GitHub
Browse files

[MODEL] FalconH1 (#18406)


Signed-off-by: default avatardhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: default avataryounesbelkada <younesbelkada@gmail.com>
Co-authored-by: default avatarIlyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: default avatarJingwei Zuo <jingwei.zuo@tii.ae>
parent 61acfc45
...@@ -392,6 +392,11 @@ Specified using `--task generate`. ...@@ -392,6 +392,11 @@ Specified using `--task generate`.
* `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. * `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `FalconH1ForCausalLM`
* Falcon-H1
* `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc.
* ✅︎
* ✅︎
- * `GemmaForCausalLM` - * `GemmaForCausalLM`
* Gemma * Gemma
* `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. * `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc.
......
...@@ -147,6 +147,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -147,6 +147,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
is_available_online=False,
min_transformers_version="4.52.2"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
......
...@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
@CustomOp.register("mixer2_gated_rms_norm") @CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp): class Mixer2RMSNormGated(CustomOp):
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
self.n_groups = full_hidden_size // self.group_size self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps self.variance_epsilon = eps
self.use_rms_norm = use_rms_norm
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight, set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)}) {"weight_loader": sharded_weight_loader(0)})
assert self.full_hidden_size % self.tp_size== 0,\ else:
"Tensor parallel world size must divide hidden size." # Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
def forward_native( def forward_native(
self, self,
...@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
# the input and then redundantly compute the RMSNorm. # the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32)) x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
if self.n_groups == 1: if self.n_groups == 1:
if self.tp_size > 1: if self.tp_size > 1:
...@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
global_sums = tensor_model_parallel_all_reduce(local_sums) global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance # Calculate the variance
count = self.tp_size * x.shape[-1] count = self.tp_size * x.shape[-1]
variance = (global_sums / count) variance = global_sums / count
else: else:
variance = x.pow(2).mean(-1, keepdim=True) variance = x.pow(2).mean(-1, keepdim=True)
...@@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -106,6 +118,9 @@ class Mixer2RMSNormGated(CustomOp):
gate: torch.Tensor, gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
if self.tp_size > 1 or self.n_groups != 1: if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate) return self.forward_native(x, gate)
...@@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader( ...@@ -182,13 +197,15 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well. # seem to handle slices well.
# https://github.com/python/mypy/issues/2410 # https://github.com/python/mypy/issues/2410
param.data[ param.data[
boundary:(boundary + take), # type: ignore[misc] boundary:(boundary + take),
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] ... # type: ignore[misc]
loaded_start_idx + take)] # type: ignore[misc] ] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries # move indexing boundaries
boundary += shard_size boundary += shard_size
loaded_boundary += (full_dim - extra) loaded_boundary += full_dim - extra
return loader return loader
...@@ -206,7 +223,8 @@ class MambaMixer2(CustomOp): ...@@ -206,7 +223,8 @@ class MambaMixer2(CustomOp):
**selective** state spaces) **selective** state spaces)
""" """
def __init__(self, def __init__(
self,
hidden_size: int, hidden_size: int,
ssm_state_size: int, ssm_state_size: int,
conv_kernel_size: int, conv_kernel_size: int,
...@@ -217,8 +235,10 @@ class MambaMixer2(CustomOp): ...@@ -217,8 +235,10 @@ class MambaMixer2(CustomOp):
num_heads: int = 128, num_heads: int = 128,
head_dim: int = 64, head_dim: int = 64,
rms_norm_eps: float = 1e-5, rms_norm_eps: float = 1e-5,
activation="silu", activation: str = "silu",
quant_config: Optional[QuantizationConfig] = None): use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__() super().__init__()
# For TP, the sharding plan is as follows: # For TP, the sharding plan is as follows:
...@@ -238,17 +258,16 @@ class MambaMixer2(CustomOp): ...@@ -238,17 +258,16 @@ class MambaMixer2(CustomOp):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
assert num_heads % self.tp_size == 0, \ assert (num_heads % self.tp_size == 0
"Tensor parallel world size must divide num heads." ), "Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
(
"If tensor parallel world size does not divide num_heads, " "If tensor parallel world size does not divide num_heads, "
"then num_groups must equal 1." "then num_groups must equal 1.")
)
assert self.tp_size == 1 or quant_config is None, \ assert (
"Tensor parallel currently not supported for quantized models." self.tp_size == 1 or quant_config is None
), "Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.activation = activation self.activation = activation
...@@ -265,8 +284,7 @@ class MambaMixer2(CustomOp): ...@@ -265,8 +284,7 @@ class MambaMixer2(CustomOp):
self.n_groups = n_groups + extra_groups_for_head_shards( self.n_groups = n_groups + extra_groups_for_head_shards(
n_groups, self.tp_size) n_groups, self.tp_size)
self.conv_dim = (intermediate_size + self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
2 * self.n_groups * ssm_state_size)
self.conv1d = ColumnParallelLinear( self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size, input_size=conv_kernel_size,
output_size=self.conv_dim, output_size=self.conv_dim,
...@@ -279,11 +297,12 @@ class MambaMixer2(CustomOp): ...@@ -279,11 +297,12 @@ class MambaMixer2(CustomOp):
# doesn't allow to override it # doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = ColumnParallelLinear(input_size=hidden_size, self.in_proj = ColumnParallelLinear(
output_size=intermediate_size + input_size=hidden_size,
self.conv_dim + self.num_heads, output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias, bias=use_bias,
quant_config=quant_config) quant_config=quant_config,
)
# - because in_proj is a concatenation of 3 weights, we # - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding # need to interleave them before sharding
...@@ -305,7 +324,8 @@ class MambaMixer2(CustomOp): ...@@ -305,7 +324,8 @@ class MambaMixer2(CustomOp):
# - ditto for the otther two weights below # - ditto for the otther two weights below
delattr(self.conv1d.bias, "weight_loader") delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.bias, { self.conv1d.bias,
{
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader( mamba_v2_sharded_weight_loader(
[ [
...@@ -316,18 +336,25 @@ class MambaMixer2(CustomOp): ...@@ -316,18 +336,25 @@ class MambaMixer2(CustomOp):
self.tp_size, self.tp_size,
tp_rank, tp_rank,
) )
}) },
)
delattr(self.conv1d.weight, "weight_loader") delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.weight, { self.conv1d.weight,
{
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader([ mamba_v2_sharded_weight_loader(
[
intermediate_settings, intermediate_settings,
group_shard_settings, group_shard_settings,
group_shard_settings, group_shard_settings,
], self.tp_size, tp_rank) ],
}) self.tp_size,
tp_rank,
)
},
)
if quant_config is None: if quant_config is None:
# - quant layers do not have a weight loader # - quant layers do not have a weight loader
...@@ -345,8 +372,10 @@ class MambaMixer2(CustomOp): ...@@ -345,8 +372,10 @@ class MambaMixer2(CustomOp):
head_setings, # for dt head_setings, # for dt
], ],
self.tp_size, self.tp_size,
tp_rank) tp_rank,
}) )
},
)
# - these are TPed by heads to reduce the size of the # - these are TPed by heads to reduce the size of the
# temporal shape # temporal shape
...@@ -357,6 +386,7 @@ class MambaMixer2(CustomOp): ...@@ -357,6 +386,7 @@ class MambaMixer2(CustomOp):
)) ))
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader( a_weight_loader = composed_weight_loader(
...@@ -365,18 +395,25 @@ class MambaMixer2(CustomOp): ...@@ -365,18 +395,25 @@ class MambaMixer2(CustomOp):
set_weight_attrs(self.dt_bias, set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)}) {"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(intermediate_size, self.out_proj = RowParallelLinear(
intermediate_size,
hidden_size, hidden_size,
bias=use_bias, bias=use_bias,
input_is_parallel=True, input_is_parallel=True,
quant_config=quant_config) quant_config=quant_config,
)
self.norm = Mixer2RMSNormGated(intermediate_size, self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups, n_groups,
self.use_rms_norm,
eps=rms_norm_eps) eps=rms_norm_eps)
def forward_native(self, hidden_states: torch.Tensor, def forward_native(
conv_state: torch.Tensor, ssm_state: torch.Tensor): self,
hidden_states: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
pass pass
def forward_cuda( def forward_cuda(
...@@ -384,6 +421,7 @@ class MambaMixer2(CustomOp): ...@@ -384,6 +421,7 @@ class MambaMixer2(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata, mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
): ):
# mamba2_metadata contains metadata necessary for the mamba2 triton # mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
...@@ -401,6 +439,10 @@ class MambaMixer2(CustomOp): ...@@ -401,6 +439,10 @@ class MambaMixer2(CustomOp):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states) projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
gate, hidden_states_B_C, dt = torch.split( gate, hidden_states_B_C, dt = torch.split(
projected_states, projected_states,
[ [
...@@ -561,6 +603,9 @@ class MambaMixer2(CustomOp): ...@@ -561,6 +603,9 @@ class MambaMixer2(CustomOp):
hidden_states = torch.vstack(ssd_output_list) hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP # 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate) hidden_states = self.norm(hidden_states, gate)
# 5. Final linear projection # 5. Final linear projection
......
This diff is collapsed.
...@@ -79,6 +79,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -79,6 +79,7 @@ _TEXT_GENERATION_MODELS = {
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
"FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"),
"Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment