Unverified Commit d1c956dc authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

Gemma3n (Text-only) (#20134)


Signed-off-by: default avatarrshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.me>
Co-authored-by: default avatarRoger Wang <hey@rogerw.me>
parent dec197e3
...@@ -336,6 +336,7 @@ Specified using `--task generate`. ...@@ -336,6 +336,7 @@ Specified using `--task generate`.
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | | `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
...@@ -392,6 +393,9 @@ Specified using `--task generate`. ...@@ -392,6 +393,9 @@ Specified using `--task generate`.
!!! note !!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
!!! note
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
### Pooling Models ### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models. See [this page](./pooling_models.md) for more information on how to use pooling models.
......
...@@ -164,6 +164,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -164,6 +164,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"), "Glm4ForCausalLM": _HfExamplesInfo("THUDM/GLM-4-9B-0414"),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
......
...@@ -135,6 +135,57 @@ class MulAndSilu(CustomOp): ...@@ -135,6 +135,57 @@ class MulAndSilu(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
"""An activation function for GeluAndMulSparse.
This activation function is used in Gemma3n. It computes:
up_proj = self.up_proj(x)
gate_proj = self.gate_proj(x)
gate_proj = self._gaussian_topk(gate_proj) # sparsity
activations = self.act_fn(gate_proj) # gelu
down_proj = self.down_proj(activations * up_proj)
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, activation_sparsity: float, approximate: str = "none"):
super().__init__()
# Gelu.
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
# Sparsity.
if activation_sparsity == 0.0:
raise ValueError(
"activation_sparsity is 0.0. Please use GeluAndMul.")
target_sparsity_tensor = torch.tensor(activation_sparsity,
dtype=torch.float32)
normal_dist = torch.distributions.normal.Normal(0, 1)
self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)
def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
"""Get % sparse percentile of the Gaussian distribution."""
# NOTE(rob): for TP>1, we could all-gather to get the means/std.
# But we do not do this because in expectation they are the same
# and in practice the eval scores are good without gathering.
mean = torch.mean(x, dim=-1, keepdim=True)
std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
cutoff_x = mean + std * self.std_multiplier
return nn.functional.relu(x - cutoff_x)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
out = self._gaussian_topk(x[..., :d])
out = F.gelu(out, approximate=self.approximate)
return out * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
@CustomOp.register("gelu_and_mul") @CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp): class GeluAndMul(CustomOp):
"""An activation function for GeGLU. """An activation function for GeGLU.
......
This diff is collapsed.
...@@ -58,6 +58,8 @@ _TEXT_GENERATION_MODELS = { ...@@ -58,6 +58,8 @@ _TEXT_GENERATION_MODELS = {
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
#TODO(ywang96): Support multimodal gemma3n
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment