Unverified Commit 6a03942d authored by Ao Tang's avatar Ao Tang Committed by GitHub
Browse files

Add Nemotron HF Support (#31699)

* Add nemotron support

* fix inference

* add unit test

* add layernorm1p as a class to avoid meta device mismatch

* test fixed

* Add copied_from statements

* remove pretraining_tp args

* remove nemotronlayernorm

* force LN computation done in FP32

* remove nemotrontokenizer and use llamatokenizer

* license update

* add option for kv_channels for minitron8b

* remove assert

* o_proj fixed

* o_proj reshape

* add gated_proj option

* typo

* remove todos

* fix broken test after merging latest main

* remove nezha/nat after meging main

* chnage default config to 15b model

* add nemo conversion script

* rename conversion script

* remove gate_proj option

* pr comment resolved

* fix unit test

* rename kv_channels to head_dim

* resolve PR issue

* add nemotron md

* fix broken tests

* refactor rope for nemotron

* test fix

* remove linearscaling

* whitespace and import

* fix some copied-from

* code style fix

* reformatted

* add position_embedding to nemotronattention

* rope refactor to only use config, copied-from fix

* format

* Run make fix-copies

* nemotron md with autodoc

* doc  fix

* fix order

* pass check_config_docstrings.py

* fix config_attributes

* remove all llama BC related code

* Use PreTrainedTokenizerFast

* ruff check examples

* conversion script update

* add nemotron to toctree
parent 36fd35e1
...@@ -468,6 +468,8 @@ ...@@ -468,6 +468,8 @@
title: MT5 title: MT5
- local: model_doc/mvp - local: model_doc/mvp
title: MVP title: MVP
- local: model_doc/nemotron
title: Nemotron
- local: model_doc/nezha - local: model_doc/nezha
title: NEZHA title: NEZHA
- local: model_doc/nllb - local: model_doc/nllb
......
...@@ -222,6 +222,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -222,6 +222,7 @@ Flax), PyTorch, and/or TensorFlow.
| [MusicGen Melody](model_doc/musicgen_melody) | ✅ | ❌ | ❌ | | [MusicGen Melody](model_doc/musicgen_melody) | ✅ | ❌ | ❌ |
| [MVP](model_doc/mvp) | ✅ | ❌ | ❌ | | [MVP](model_doc/mvp) | ✅ | ❌ | ❌ |
| [NAT](model_doc/nat) | ✅ | ❌ | ❌ | | [NAT](model_doc/nat) | ✅ | ❌ | ❌ |
| [Nemotron](model_doc/nemotron) | ✅ | ❌ | ❌ |
| [Nezha](model_doc/nezha) | ✅ | ❌ | ❌ | | [Nezha](model_doc/nezha) | ✅ | ❌ | ❌ |
| [NLLB](model_doc/nllb) | ✅ | ❌ | ❌ | | [NLLB](model_doc/nllb) | ✅ | ❌ | ❌ |
| [NLLB-MOE](model_doc/nllb-moe) | ✅ | ❌ | ❌ | | [NLLB-MOE](model_doc/nllb-moe) | ✅ | ❌ | ❌ |
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Nemotron
## Nemotron
### License
The use of this model is governed by the [NVIDIA AI Foundation Models Community License Agreement](https://developer.nvidia.com/downloads/nv-ai-foundation-models-license).
### Description
Nemotron-4 is a family of enterprise ready generative text models compatible with [NVIDIA NeMo Framework](https://www.nvidia.com/en-us/ai-data-science/generative-ai/nemo-framework/).
NVIDIA NeMo is an end-to-end, cloud-native platform to build, customize, and deploy generative AI models anywhere. It includes training and inferencing frameworks, guardrailing toolkits, data curation tools, and pretrained models, offering enterprises an easy, cost-effective, and fast way to adopt generative AI. To get access to NeMo Framework, please sign up at [this link](https://developer.nvidia.com/nemo-framework/join).
### References
[Announcement Blog](https://developer.nvidia.com/blog/nvidia-ai-foundation-models-build-custom-enterprise-chatbots-and-co-pilots-with-production-ready-llms/)
### Model Architecture
**Architecture Type:** Transformer
**Network Architecture:** Transformer Decoder (auto-regressive language model).
## Minitron
### Minitron 4B Base
Minitron is a family of small language models (SLMs) obtained by pruning NVIDIA's [Nemotron-4 15B](https://arxiv.org/abs/2402.16819) model. We prune model embedding size, attention heads, and MLP intermediate dimension, following which, we perform continued training with distillation to arrive at the final models.
Deriving the Minitron 8B and 4B models from the base 15B model using our approach requires up to **40x fewer training tokens** per model compared to training from scratch; this results in **compute cost savings of 1.8x** for training the full model family (15B, 8B, and 4B). Minitron models exhibit up to a 16% improvement in MMLU scores compared to training from scratch, perform comparably to other community models such as Mistral 7B, Gemma 7B and Llama-3 8B, and outperform state-of-the-art compression techniques from the literature. Please refer to our [arXiv paper](https://arxiv.org/abs/2407.14679) for more details.
Minitron models are for research and development only.
### HuggingFace Quickstart
The following code provides an example of how to load the Minitron-4B model and use it to perform text generation.
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and model
model_path = 'nvidia/Minitron-4B-Base'
tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda'
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)
# Prepare the input text
prompt = 'Complete the paragraph: our solar system is'
inputs = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
# Generate the output
outputs = model.generate(inputs, max_length=20)
# Decode and print the output
output_text = tokenizer.decode(outputs[0])
print(output_text)
```
### License
Minitron is released under the [NVIDIA Open Model License Agreement](https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf).
### Evaluation Results
*5-shot performance.* Language Understanding evaluated using [Massive Multitask Language Understanding](https://arxiv.org/abs/2009.03300):
| Average |
| :---- |
| 58.6 |
*Zero-shot performance.* Evaluated using select datasets from the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) with additions:
| HellaSwag | Winogrande | GSM8K| ARC-C | XLSum |
| :------------- | :------------- | :------------- | :------------- | :------------- |
| 75.0 | 74.0 | 24.1 | 50.9 | 29.5
*Code generation performance*. Evaluated using [HumanEval](https://github.com/openai/human-eval):
| p@1, 0-Shot |
| :------------- |
| 23.3 |
Please refer to our [paper](https://arxiv.org/abs/2407.14679) for the full set of results.
### Citation
If you find our work helpful, please consider citing our paper:
```
@article{minitron2024,
title={Compact Language Models via Pruning and Knowledge Distillation},
author={Saurav Muralidharan and Sharath Turuvekere Sreenivas and Raviraj Joshi and Marcin Chochowski and Mostofa Patwary and Mohammad Shoeybi and Bryan Catanzaro and Jan Kautz and Pavlo Molchanov},
journal={arXiv preprint arXiv:2407.14679},
year={2024},
url={https://arxiv.org/abs/2407.14679},
}
```
## NemotronConfig
[[autodoc]] NemotronConfig
## NemotronModel
[[autodoc]] NemotronModel
- forward
## NemotronForCausalLM
[[autodoc]] NemotronForCausalLM
- forward
## NemotronForSequenceClassification
[[autodoc]] NemotronForSequenceClassification
- forward
## NemotronForQuestionAnswering
[[autodoc]] NemotronForQuestionAnswering
- forward
## NemotronForTokenClassification
[[autodoc]] NemotronForTokenClassification
- forward
\ No newline at end of file
...@@ -67,6 +67,7 @@ FlashAttention-2 is currently supported for the following architectures: ...@@ -67,6 +67,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
...@@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following arc ...@@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel) * [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel) * [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel) * [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
......
...@@ -592,6 +592,7 @@ _import_structure = { ...@@ -592,6 +592,7 @@ _import_structure = {
"MusicgenMelodyDecoderConfig", "MusicgenMelodyDecoderConfig",
], ],
"models.mvp": ["MvpConfig", "MvpTokenizer"], "models.mvp": ["MvpConfig", "MvpTokenizer"],
"models.nemotron": ["NemotronConfig"],
"models.nllb": [], "models.nllb": [],
"models.nllb_moe": ["NllbMoeConfig"], "models.nllb_moe": ["NllbMoeConfig"],
"models.nougat": ["NougatProcessor"], "models.nougat": ["NougatProcessor"],
...@@ -2742,6 +2743,16 @@ else: ...@@ -2742,6 +2743,16 @@ else:
"MvpPreTrainedModel", "MvpPreTrainedModel",
] ]
) )
_import_structure["models.nemotron"].extend(
[
"NemotronForCausalLM",
"NemotronForQuestionAnswering",
"NemotronForSequenceClassification",
"NemotronForTokenClassification",
"NemotronModel",
"NemotronPreTrainedModel",
]
)
_import_structure["models.nllb_moe"].extend( _import_structure["models.nllb_moe"].extend(
[ [
"NllbMoeForConditionalGeneration", "NllbMoeForConditionalGeneration",
...@@ -5286,6 +5297,7 @@ if TYPE_CHECKING: ...@@ -5286,6 +5297,7 @@ if TYPE_CHECKING:
MusicgenMelodyDecoderConfig, MusicgenMelodyDecoderConfig,
) )
from .models.mvp import MvpConfig, MvpTokenizer from .models.mvp import MvpConfig, MvpTokenizer
from .models.nemotron import NemotronConfig
from .models.nllb_moe import NllbMoeConfig from .models.nllb_moe import NllbMoeConfig
from .models.nougat import NougatProcessor from .models.nougat import NougatProcessor
from .models.nystromformer import ( from .models.nystromformer import (
...@@ -7187,6 +7199,14 @@ if TYPE_CHECKING: ...@@ -7187,6 +7199,14 @@ if TYPE_CHECKING:
MvpModel, MvpModel,
MvpPreTrainedModel, MvpPreTrainedModel,
) )
from .models.nemotron import (
NemotronForCausalLM,
NemotronForQuestionAnswering,
NemotronForSequenceClassification,
NemotronForTokenClassification,
NemotronModel,
NemotronPreTrainedModel,
)
from .models.nllb_moe import ( from .models.nllb_moe import (
NllbMoeForConditionalGeneration, NllbMoeForConditionalGeneration,
NllbMoeModel, NllbMoeModel,
......
...@@ -159,6 +159,7 @@ from . import ( ...@@ -159,6 +159,7 @@ from . import (
musicgen, musicgen,
musicgen_melody, musicgen_melody,
mvp, mvp,
nemotron,
nllb, nllb,
nllb_moe, nllb_moe,
nougat, nougat,
......
...@@ -177,6 +177,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ...@@ -177,6 +177,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("musicgen_melody", "MusicgenMelodyConfig"), ("musicgen_melody", "MusicgenMelodyConfig"),
("mvp", "MvpConfig"), ("mvp", "MvpConfig"),
("nat", "NatConfig"), ("nat", "NatConfig"),
("nemotron", "NemotronConfig"),
("nezha", "NezhaConfig"), ("nezha", "NezhaConfig"),
("nllb-moe", "NllbMoeConfig"), ("nllb-moe", "NllbMoeConfig"),
("nougat", "VisionEncoderDecoderConfig"), ("nougat", "VisionEncoderDecoderConfig"),
...@@ -469,6 +470,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -469,6 +470,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("musicgen_melody", "MusicGen Melody"), ("musicgen_melody", "MusicGen Melody"),
("mvp", "MVP"), ("mvp", "MVP"),
("nat", "NAT"), ("nat", "NAT"),
("nemotron", "Nemotron"),
("nezha", "Nezha"), ("nezha", "Nezha"),
("nllb", "NLLB"), ("nllb", "NLLB"),
("nllb-moe", "NLLB-MOE"), ("nllb-moe", "NLLB-MOE"),
......
...@@ -169,6 +169,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ...@@ -169,6 +169,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("musicgen_melody", "MusicgenMelodyModel"), ("musicgen_melody", "MusicgenMelodyModel"),
("mvp", "MvpModel"), ("mvp", "MvpModel"),
("nat", "NatModel"), ("nat", "NatModel"),
("nemotron", "NemotronModel"),
("nezha", "NezhaModel"), ("nezha", "NezhaModel"),
("nllb-moe", "NllbMoeModel"), ("nllb-moe", "NllbMoeModel"),
("nystromformer", "NystromformerModel"), ("nystromformer", "NystromformerModel"),
...@@ -481,6 +482,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -481,6 +482,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("musicgen", "MusicgenForCausalLM"), ("musicgen", "MusicgenForCausalLM"),
("musicgen_melody", "MusicgenMelodyForCausalLM"), ("musicgen_melody", "MusicgenMelodyForCausalLM"),
("mvp", "MvpForCausalLM"), ("mvp", "MvpForCausalLM"),
("nemotron", "NemotronForCausalLM"),
("olmo", "OlmoForCausalLM"), ("olmo", "OlmoForCausalLM"),
("open-llama", "OpenLlamaForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"), ("openai-gpt", "OpenAIGPTLMHeadModel"),
...@@ -902,6 +904,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -902,6 +904,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mra", "MraForSequenceClassification"), ("mra", "MraForSequenceClassification"),
("mt5", "MT5ForSequenceClassification"), ("mt5", "MT5ForSequenceClassification"),
("mvp", "MvpForSequenceClassification"), ("mvp", "MvpForSequenceClassification"),
("nemotron", "NemotronForSequenceClassification"),
("nezha", "NezhaForSequenceClassification"), ("nezha", "NezhaForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"),
("open-llama", "OpenLlamaForSequenceClassification"), ("open-llama", "OpenLlamaForSequenceClassification"),
...@@ -983,6 +986,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ...@@ -983,6 +986,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("mra", "MraForQuestionAnswering"), ("mra", "MraForQuestionAnswering"),
("mt5", "MT5ForQuestionAnswering"), ("mt5", "MT5ForQuestionAnswering"),
("mvp", "MvpForQuestionAnswering"), ("mvp", "MvpForQuestionAnswering"),
("nemotron", "NemotronForQuestionAnswering"),
("nezha", "NezhaForQuestionAnswering"), ("nezha", "NezhaForQuestionAnswering"),
("nystromformer", "NystromformerForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"),
("opt", "OPTForQuestionAnswering"), ("opt", "OPTForQuestionAnswering"),
...@@ -1078,6 +1082,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -1078,6 +1082,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mpt", "MptForTokenClassification"), ("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"), ("mra", "MraForTokenClassification"),
("mt5", "MT5ForTokenClassification"), ("mt5", "MT5ForTokenClassification"),
("nemotron", "NemotronForTokenClassification"),
("nezha", "NezhaForTokenClassification"), ("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"),
("persimmon", "PersimmonForTokenClassification"), ("persimmon", "PersimmonForTokenClassification"),
......
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_torch_available,
)
_import_structure = {
"configuration_nemotron": ["NemotronConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_nemotron"] = [
"NemotronForQuestionAnswering",
"NemotronForCausalLM",
"NemotronModel",
"NemotronPreTrainedModel",
"NemotronForSequenceClassification",
"NemotronForTokenClassification",
]
if TYPE_CHECKING:
from .configuration_nemotron import NemotronConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_nemotron import (
NemotronForCausalLM,
NemotronForQuestionAnswering,
NemotronForSequenceClassification,
NemotronForTokenClassification,
NemotronModel,
NemotronPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nemotron model configuration"""
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
logger = logging.get_logger(__name__)
class NemotronConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`NemotronModel`]. It is used to instantiate an Nemotron
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Nemotron-8B.
e.g. [nvidia/nemotron-3-8b-base-4k-hf](https://huggingface.co/nvidia/nemotron-3-8b-base-4k-hf).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Nemotron model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`NemotronModel`]
hidden_size (`int`, *optional*, defaults to 6144):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 48):
Number of attention heads for each attention layer in the Transformer decoder.
head_dim (`int`, *optional*):
Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if None
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.0134):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 3):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj and down_proj layers in the MLP layers.
```python
>>> from transformers import NemotronModel, NemotronConfig
>>> # Initializing a Nemotron nemotron-15b style configuration
>>> configuration = NemotronConfig()
>>> # Initializing a model from the nemotron-15b style configuration
>>> model = NemotronModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "nemotron"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=6144,
intermediate_size=24576,
num_hidden_layers=32,
num_attention_heads=48,
head_dim=None,
num_key_value_heads=None,
hidden_act="relu2",
max_position_embeddings=4096,
initializer_range=0.0134,
norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=2,
eos_token_id=3,
tie_word_embeddings=False,
rope_theta=10000.0,
partial_rotary_factor=0.5,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.partial_rotary_factor = partial_rotary_factor
rope_config_validation(self)
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
from argparse import ArgumentParser
from collections import OrderedDict
import torch
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging
from pytorch_lightning import Trainer
from transformers import LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import LlamaConverter
"""
Script to convert a nemotron checkpoint in nemo (mcore path) into a HuggingFace checkpoint.
This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.
1) Generate only HF weights from a nemo file:
python convert_nemotron_nemo_to_hf.py \
--input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
--output_path /path/to/pytorch_model.bin
2) Generate the full HF model folder
python convert_nemotron_nemo_to_hf.py \
--input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
--hf_input_path /path/to/input_hf_folder \
--hf_output_path /path/to/output_hf_folder \
Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Nemotron4 340b).
However this option makes the conversion script significantly slower.
"""
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--input_name_or_path",
type=str,
default=None,
required=True,
help="Path to .nemo file or extracted folder",
)
parser.add_argument("--output_path", type=str, default=None, required=False, help="Path to HF .bin file")
parser.add_argument(
"--hf_input_path",
type=str,
default=None,
help="A HF model path, " "e.g. a folder containing https://huggingface.co/nvidia/Minitron-8B-Base",
)
parser.add_argument(
"--hf_output_path",
type=str,
default=None,
help="Output HF model path, " "with the same format as above but user's own weights",
)
parser.add_argument(
"--precision",
type=str,
default=None,
help="Precision of output weights."
"Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
)
parser.add_argument(
"--cpu-only",
action="store_true",
help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
"but this option makes the conversion script significantly slower.",
)
args = parser.parse_args()
return args
def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, hf_url="nvidia/Minitron-8B-Base"):
"""
Convert NeMo config to HF config
"""
NEMO_ACT2HF = {
"squared-relu": "relu2",
"fast-swiglu": "silu",
}
DTYPE2HF = {
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
}
hf_config = {
"_name_or_path": hf_url,
"architectures": ["NemotronForCausalLM"],
"bos_token_id": tokenizer.bos_id,
"eos_token_id": tokenizer.eos_id,
"hidden_act": NEMO_ACT2HF[nemo_config.activation],
"hidden_size": nemo_config.hidden_size,
"initializer_range": nemo_config.init_method_std,
"intermediate_size": nemo_config.ffn_hidden_size,
"max_position_embeddings": nemo_config.max_position_embeddings,
"model_type": "nemotron",
"num_attention_heads": nemo_config.num_attention_heads,
"num_hidden_layers": nemo_config.num_layers,
"num_key_value_heads": nemo_config.get("num_query_groups", nemo_config.num_attention_heads),
"norm_eps": nemo_config.layernorm_epsilon,
"rope_theta": nemo_config.get("rotary_base", 10000),
"partial_rotary_factor": nemo_config.get("rotary_percentage", 1.0),
"tie_word_embeddings": False,
"torch_dtype": DTYPE2HF[dtype],
"transformers_version": "4.32.0.dev0", # TODO
"use_cache": True,
"vocab_size": vocab_size,
}
if nemo_config.kv_channels is not None:
hf_config["kv_channels"] = nemo_config.kv_channels
json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)
def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
"""
Convert NeMo weights to HF weights
"""
dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
model_config.tensor_model_parallel_size = 1
model_config.pipeline_model_parallel_size = 1
model_config.sequence_parallel = False
model_config.transformer_engine = True
if cpu_only:
map_location = torch.device("cpu")
model_config.use_cpu_initialization = True
model_config.dist_ckpt_load_on_device = False
else:
map_location = None
if cpu_only:
logging.info("******** Loading model on CPU. This will take a significant amount of time.")
model = MegatronGPTModel.restore_from(
input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
)
vocab_size = model.padded_vocab_size
if precision is None:
precision = model.cfg.precision
if precision in [32, "32"]:
dtype = torch.float32
elif precision in [16, "16", "16-mixed"]:
dtype = torch.float16
elif precision in ["bf16", "bf16-mixed"]:
dtype = torch.bfloat16
else:
logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
dtype = torch.float32 # fallback
logging.info(f"Using precision {dtype}")
def param_to_weights(param):
return param.to(dtype)
checkpoint = OrderedDict()
hidden_size = model.cfg.hidden_size
head_num = model.cfg.num_attention_heads
num_layers = model.cfg.num_layers
ffn_hidden_size = model.cfg.ffn_hidden_size
num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B
if num_query_groups is None:
num_query_groups = head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups
# Embedding
embed_weight = model.state_dict()["model.embedding.word_embeddings.weight"]
embed_weights_base_name = "model.embed_tokens.weight"
checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)
for l in range(int(num_layers)):
print(f"converting layer {l}")
qkv_weights = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.weight"]
qkv_weights = qkv_weights.reshape([qkv_total_dim, -1, hidden_size])
q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
## Example of slices
## (without GQA): num_query_groups = head_num = 32,
## q_slice = [0, 3, 6, 9 , ... 90, 93]
## k_slice = [1, 4, 7, 10, ... 91, 94]
## v_slice = [2, 5, 8, 11, ... 92, 95]
## (with GQA): num_query_groups = 8, head_num = 64
## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
## k_slice = [8, 18, 28, ... 68, 78]
## v_slice = [9, 19, 29, ... 69, 79]
q_weights_base_name = f"model.layers.{l}.self_attn.q_proj.weight"
k_weights_base_name = f"model.layers.{l}.self_attn.k_proj.weight"
v_weights_base_name = f"model.layers.{l}.self_attn.v_proj.weight"
checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
# attention dense
o_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_proj.weight"]
o_weight_base_name = f"model.layers.{l}.self_attn.o_proj.weight"
checkpoint[o_weight_base_name] = param_to_weights(o_weight)
# mlp
mlp_weights = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.weight"]
mlp_up_proj_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc2.weight"]
if mlp_weights.shape[0] != mlp_up_proj_weight.shape[1]:
# Has projection (used for swi-glu)
logging.warning(
"Gated projection layers detected in NeMo checkpoint. Currently Nemotron HF does not support gated MLP."
)
assert mlp_weights.shape[0] == 2 * mlp_up_proj_weight.shape[1]
mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :]
mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :]
mlp_down_proj_base_name = f"model.layers.{l}.mlp.gate_proj.weight"
mlp_gate_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
else:
mlp_down_proj_weight = mlp_weights
mlp_down_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
mlp_up_proj_base_name = f"model.layers.{l}.mlp.down_proj.weight"
checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)
# layernorm
input_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight"]
input_ln_base_name = f"model.layers.{l}.input_layernorm.weight"
checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
if (
model.state_dict().get(f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias", None)
is not None
):
input_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias"]
input_ln_bias_name = f"model.layers.{l}.input_layernorm.bias"
checkpoint[input_ln_bias_name] = param_to_weights(input_ln_bias)
post_attn_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight"]
post_attn_ln_base_name = f"model.layers.{l}.post_attention_layernorm.weight"
checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
if model.state_dict().get(f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias", None) is not None:
post_attn_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias"]
post_attn_ln_bias_name = f"model.layers.{l}.post_attention_layernorm.bias"
checkpoint[post_attn_ln_bias_name] = param_to_weights(post_attn_ln_bias)
print(f"done layer {l}")
final_ln_weight = model.state_dict()["model.decoder.final_layernorm.weight"]
final_ln_base_name = "model.norm.weight"
checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
if model.state_dict().get("model.decoder.final_layernorm.bias", None) is not None:
final_ln_bias = model.state_dict()["model.decoder.final_layernorm.bias"]
final_ln_bias_name = "model.norm.bias"
checkpoint[final_ln_bias_name] = param_to_weights(final_ln_bias)
output_layer_weight = model.state_dict()["model.output_layer.weight"]
output_layer_base_name = "lm_head.weight"
checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)
os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
torch.save(checkpoint, output_hf_file)
logging.info(f"Weights saved to {output_hf_file}")
return model_config, model.tokenizer, dtype, vocab_size
def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tokenizer):
tokenizer_cfg = model_config.tokenizer
if tokenizer_cfg.library == "sentencepiece":
# For sentencepiece tokenizer, we are wrapping with HF's LlamaTokenizer
# and convert it to a PreTrainedTokenizerFast
tokenizer_fn = tokenizer_cfg.model[5:]
output_tokenizer = f"{output_hf_path}/tokenizer.model"
if nemo_file.endswith(".nemo"):
import tarfile
archive = tarfile.open(nemo_file, "r")
tokenizer_filename = "./" + tokenizer_fn # exclude 'nemo:' prefix
archive.extract(tokenizer_filename, output_hf_path)
archive.close()
os.rename(f"{output_hf_path}/{tokenizer_fn}", output_tokenizer)
elif os.path.isdir(nemo_file):
shutil.copy(f"{nemo_file}/{tokenizer_fn}", output_tokenizer)
# We use LlamaTokenizer for sentencepiece based tokenizer
tokenizer = LlamaTokenizer.from_pretrained(output_hf_path, legacy=False)
# Convert the LlamaTokenizer to a PreTrainedTokenizerFast instance
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=LlamaConverter(tokenizer).converted(), model_input_names=["input_ids", "token_type_ids"]
)
tokenizer.save_pretrained(output_hf_path)
logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}")
elif isinstance(nemo_tokenizer, AutoTokenizer):
nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
else:
raise ValueError(f"Unsupported tokenizer type: library: {tokenizer_cfg.library}, type: {tokenizer_cfg.type}")
if __name__ == "__main__":
args = get_args()
if not args.hf_output_path:
assert args.output_path is not None, "Need to provide either output_path or hf_output_path"
else:
args.output_path = f"{args.hf_output_path}/pytorch_model.bin"
logging.info(f"weight will be saved to {args.output_path}")
nemo_config, nemo_tokenizer, dtype, vocab_size = convert(
args.input_name_or_path, args.output_path, precision=args.precision, cpu_only=args.cpu_only
)
if args.hf_input_path and args.hf_output_path:
convert_hf_config(nemo_config, nemo_tokenizer, vocab_size, dtype, args.hf_output_path, args.hf_input_path)
extract_nemotron_tokenizer(args.input_name_or_path, nemo_config, args.hf_output_path, nemo_tokenizer)
else:
logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.")
logging.info(f".bin file is saved to {args.output_path}")
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Nemotron model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Size, Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_nemotron import NemotronConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "NemotronConfig"
def _cast_if_autocast_enabled(*args):
if not torch.is_autocast_enabled():
return args
else:
return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
class NemotronLayerNorm1P(nn.LayerNorm):
def __init__(
self,
normalized_shape: Union[int, List[int], Size],
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
device=None,
dtype=None,
):
super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
def forward(self, input: Tensor) -> Tensor:
args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
with torch.cuda.amp.autocast(enabled=False):
return F.layer_norm(*args)
ALL_LAYERNORM_LAYERS.append(NemotronLayerNorm1P)
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronRotaryEmbedding(nn.Module):
# Ignore copy
def __init__(
self,
config: NemotronConfig,
device=None,
):
super().__init__()
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_kwargs = None
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
rot_dim = cos.shape[-1]
# If q_pass/k_pass is empty, rotary pos embedding is applied to all tensor q/k
q, q_pass = q[..., :rot_dim], q[..., rot_dim:]
k, k_pass = k[..., :rot_dim], k[..., rot_dim:]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return torch.cat((q_embed, q_pass), dim=-1), torch.cat((k_embed, k_pass), dim=-1)
class NemotronMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class NemotronAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: NemotronConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronFlashAttention2(NemotronAttention):
"""
Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (NemotronRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronSdpaAttention(NemotronAttention):
"""
Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`NemotronAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"NemotronModel is using NemotronSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
NEMOTRON_ATTENTION_CLASSES = {
"eager": NemotronAttention,
"flash_attention_2": NemotronFlashAttention2,
"sdpa": NemotronSdpaAttention,
}
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronDecoderLayer(nn.Module):
# Ignore copy
def __init__(self, config: NemotronConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = NEMOTRON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = NemotronMLP(config)
self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
self.post_attention_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
NEMOTRON_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`NemotronConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Nemotron Model outputting raw hidden-states without any specific head on top.",
NEMOTRON_START_DOCSTRING,
)
class NemotronPreTrainedModel(PreTrainedModel):
config_class = NemotronConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["NemotronDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
NEMOTRON_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare Nemotron Model outputting raw hidden-states without any specific head on top.",
NEMOTRON_START_DOCSTRING,
)
class NemotronModel(NemotronPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`NemotronDecoderLayer`]
Args:
config: NemotronConfig
"""
def __init__(self, config: NemotronConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[NemotronDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
self.rotary_emb = NemotronRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForCausalLM(NemotronPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = NemotronModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy (doc string different)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, NemotronForCausalLM
>>> model = NemotronForCausalLM.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
@add_start_docstrings(
"""
The Nemotron Model transformer with a sequence classification head on top (linear layer).
[`NemotronForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
NEMOTRON_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForSequenceClassification(NemotronPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = NemotronModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Nemotron Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
NEMOTRON_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForQuestionAnswering(NemotronPreTrainedModel):
base_model_prefix = "transformer"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Nemotron
def __init__(self, config):
super().__init__(config)
self.transformer = NemotronModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.transformer.embed_tokens
def set_input_embeddings(self, value):
self.transformer.embed_tokens = value
@add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The Nemotron Model transformer with a token classification head on top (a linear layer on top of the hidden-states
output) e.g. for Named-Entity-Recognition (NER) tasks.
""",
NEMOTRON_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForTokenClassification(NemotronPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = NemotronModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
...@@ -6338,6 +6338,48 @@ class MvpPreTrainedModel(metaclass=DummyObject): ...@@ -6338,6 +6338,48 @@ class MvpPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class NemotronForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NllbMoeForConditionalGeneration(metaclass=DummyObject): class NllbMoeForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Nemotron model."""
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import NemotronConfig, is_torch_available
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
from ...test_configuration_common import ConfigTester
if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
NemotronForCausalLM,
NemotronForQuestionAnswering,
NemotronForSequenceClassification,
NemotronForTokenClassification,
NemotronModel,
)
class NemotronModelTester(GemmaModelTester):
if is_torch_available():
config_class = NemotronConfig
model_class = NemotronModel
for_causal_lm_class = NemotronForCausalLM
for_sequence_class = NemotronForSequenceClassification
for_token_class = NemotronForTokenClassification
@require_torch
class NemotronModelTest(GemmaModelTest):
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
all_model_classes = (
(
NemotronModel,
NemotronForCausalLM,
NemotronForSequenceClassification,
NemotronForQuestionAnswering,
NemotronForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (NemotronForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": NemotronModel,
"text-classification": NemotronForSequenceClassification,
"text-generation": NemotronForCausalLM,
"zero-shot": NemotronForSequenceClassification,
"question-answering": NemotronForQuestionAnswering,
"token-classification": NemotronForTokenClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
# used in `test_torch_compile`
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
def setUp(self):
self.model_tester = NemotronModelTester(self)
self.config_tester = ConfigTester(self, config_class=NemotronConfig, hidden_size=37)
@require_torch_sdpa
@slow
@unittest.skip(
reason="Due to custom causal mask, there is a slightly too big difference between eager and sdpa in bfloat16."
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
pass
@unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails")
def test_model_outputs_equivalence(self, **kwargs):
pass
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(reason="Model does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
)
model_sdpa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
# nemotron sdpa needs a high tolerance
assert torch.allclose(logits_sdpa, logits, atol=1e-2)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
# nemotron flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=1e-2)
@require_torch_gpu
class NemotronIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@slow
@require_read_token
def test_nemotron_8b_generation_sdpa(self):
text = ["What is the largest planet in solar system?"]
EXPECTED_TEXT = [
"What is the largest planet in solar system?\nAnswer: Jupiter\n\nWhat is the answer",
]
model_id = "thhaus/nemotron3-8b"
model = NemotronForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
@slow
@require_read_token
def test_nemotron_8b_generation_eager(self):
text = ["What is the largest planet in solar system?"]
EXPECTED_TEXT = [
"What is the largest planet in solar system?\nAnswer: Jupiter\n\nWhat is the answer",
]
model_id = "thhaus/nemotron3-8b"
model = NemotronForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
@slow
@require_read_token
def test_nemotron_8b_generation_fa2(self):
text = ["What is the largest planet in solar system?"]
EXPECTED_TEXT = [
"What is the largest planet in solar system?\nAnswer: Jupiter\n\nWhat is the answer",
]
model_id = "thhaus/nemotron3-8b"
model = NemotronForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment