Unverified Commit 6a03942d authored by Ao Tang, committed by GitHub

Add Nemotron HF Support (#31699)

* Add nemotron support

* fix inference

* add unit test

* add layernorm1p as a class to avoid meta device mismatch

* test fixed

* Add copied_from statements

* remove pretraining_tp args

* remove nemotronlayernorm

* force LN computation done in FP32

* remove nemotrontokenizer and use llamatokenizer

* license update

* add option for kv_channels for minitron8b

* remove assert

* o_proj fixed

* o_proj reshape

* add gated_proj option

* typo

* remove todos

* fix broken test after merging latest main

* remove nezha/nat after merging main

* change default config to 15b model

* add nemo conversion script

* rename conversion script

* remove gate_proj option

* pr comment resolved

* fix unit test

* rename kv_channels to head_dim

* resolve PR issue

* add nemotron md

* fix broken tests

* refactor rope for nemotron

* test fix

* remove linearscaling

* whitespace and import

* fix some copied-from

* code style fix

* reformatted

* add position_embedding to nemotronattention

* rope refactor to only use config, copied-from fix

* format

* Run make fix-copies

* nemotron md with autodoc

* doc fix

* fix order

* pass check_config_docstrings.py

* fix config_attributes

* remove all llama BC related code

* Use PreTrainedTokenizerFast

* ruff check examples

* conversion script update

* add nemotron to toctree
parent 36fd35e1
@@ -468,6 +468,8 @@
title: MT5
- local: model_doc/mvp
title: MVP
- local: model_doc/nemotron
title: Nemotron
- local: model_doc/nezha
title: NEZHA
- local: model_doc/nllb
...
@@ -222,6 +222,7 @@ Flax), PyTorch, and/or TensorFlow.
| [MusicGen Melody](model_doc/musicgen_melody) | ✅ | ❌ | ❌ |
| [MVP](model_doc/mvp) | ✅ | ❌ | ❌ |
| [NAT](model_doc/nat) | ✅ | ❌ | ❌ |
| [Nemotron](model_doc/nemotron) | ✅ | ❌ | ❌ |
| [Nezha](model_doc/nezha) | ✅ | ❌ | ❌ |
| [NLLB](model_doc/nllb) | ✅ | ❌ | ❌ |
| [NLLB-MOE](model_doc/nllb-moe) | ✅ | ❌ | ❌ |
...
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Nemotron
## Nemotron
### License
The use of this model is governed by the [NVIDIA AI Foundation Models Community License Agreement](https://developer.nvidia.com/downloads/nv-ai-foundation-models-license).
### Description
Nemotron-4 is a family of enterprise-ready generative text models compatible with the [NVIDIA NeMo Framework](https://www.nvidia.com/en-us/ai-data-science/generative-ai/nemo-framework/).
NVIDIA NeMo is an end-to-end, cloud-native platform to build, customize, and deploy generative AI models anywhere. It includes training and inferencing frameworks, guardrailing toolkits, data curation tools, and pretrained models, offering enterprises an easy, cost-effective, and fast way to adopt generative AI. To get access to NeMo Framework, please sign up at [this link](https://developer.nvidia.com/nemo-framework/join).
### References
[Announcement Blog](https://developer.nvidia.com/blog/nvidia-ai-foundation-models-build-custom-enterprise-chatbots-and-co-pilots-with-production-ready-llms/)
### Model Architecture
**Architecture Type:** Transformer
**Network Architecture:** Transformer Decoder (auto-regressive language model).
## Minitron
### Minitron 4B Base
Minitron is a family of small language models (SLMs) obtained by pruning NVIDIA's [Nemotron-4 15B](https://arxiv.org/abs/2402.16819) model. We prune the model's embedding size, attention heads, and MLP intermediate dimension, and then perform continued training with distillation to arrive at the final models.
Deriving the Minitron 8B and 4B models from the base 15B model using our approach requires up to **40x fewer training tokens** per model compared to training from scratch; this results in **compute cost savings of 1.8x** for training the full model family (15B, 8B, and 4B). Minitron models exhibit up to a 16% improvement in MMLU scores compared to training from scratch, perform comparably to other community models such as Mistral 7B, Gemma 7B and Llama-3 8B, and outperform state-of-the-art compression techniques from the literature. Please refer to our [arXiv paper](https://arxiv.org/abs/2407.14679) for more details.
Minitron models are for research and development only.
### HuggingFace Quickstart
The following code provides an example of how to load the Minitron-4B model and use it to perform text generation.
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and model
model_path = 'nvidia/Minitron-4B-Base'
tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda'
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)
# Prepare the input text
prompt = 'Complete the paragraph: our solar system is'
inputs = tokenizer.encode(prompt, return_tensors='pt').to(model.device)
# Generate the output
outputs = model.generate(inputs, max_length=20)
# Decode and print the output
output_text = tokenizer.decode(outputs[0])
print(output_text)
```
### License
Minitron is released under the [NVIDIA Open Model License Agreement](https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf).
### Evaluation Results
*5-shot performance.* Language Understanding evaluated using [Massive Multitask Language Understanding](https://arxiv.org/abs/2009.03300):
| Average |
| :---- |
| 58.6 |
*Zero-shot performance.* Evaluated using select datasets from the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) with additions:
| HellaSwag | Winogrande | GSM8K| ARC-C | XLSum |
| :------------- | :------------- | :------------- | :------------- | :------------- |
| 75.0 | 74.0 | 24.1 | 50.9 | 29.5 |
*Code generation performance*. Evaluated using [HumanEval](https://github.com/openai/human-eval):
| p@1, 0-Shot |
| :------------- |
| 23.3 |
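The p@1 score above is pass@1 as defined for HumanEval. For reference, the standard unbiased pass@k estimator from the HumanEval paper can be sketched in a few lines (the function name here is ours, not part of the evaluation harness):

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate given n samples, of which c passed the tests."""
    if n - c < k:
        # Every size-k subset of the n samples contains at least one correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# e.g. 3 correct completions out of 10 samples, k=1  ->  approximately 0.3
print(pass_at_k(10, 3, 1))
```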
Please refer to our [paper](https://arxiv.org/abs/2407.14679) for the full set of results.
### Citation
If you find our work helpful, please consider citing our paper:
```
@article{minitron2024,
title={Compact Language Models via Pruning and Knowledge Distillation},
author={Saurav Muralidharan and Sharath Turuvekere Sreenivas and Raviraj Joshi and Marcin Chochowski and Mostofa Patwary and Mohammad Shoeybi and Bryan Catanzaro and Jan Kautz and Pavlo Molchanov},
journal={arXiv preprint arXiv:2407.14679},
year={2024},
url={https://arxiv.org/abs/2407.14679},
}
```
## NemotronConfig
[[autodoc]] NemotronConfig
## NemotronModel
[[autodoc]] NemotronModel
- forward
## NemotronForCausalLM
[[autodoc]] NemotronForCausalLM
- forward
## NemotronForSequenceClassification
[[autodoc]] NemotronForSequenceClassification
- forward
## NemotronForQuestionAnswering
[[autodoc]] NemotronForQuestionAnswering
- forward
## NemotronForTokenClassification
[[autodoc]] NemotronForTokenClassification
- forward
\ No newline at end of file
@@ -67,6 +67,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
@@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
...
@@ -592,6 +592,7 @@ _import_structure = {
"MusicgenMelodyDecoderConfig",
],
"models.mvp": ["MvpConfig", "MvpTokenizer"],
"models.nemotron": ["NemotronConfig"],
"models.nllb": [],
"models.nllb_moe": ["NllbMoeConfig"],
"models.nougat": ["NougatProcessor"],
@@ -2742,6 +2743,16 @@ else:
"MvpPreTrainedModel",
]
)
_import_structure["models.nemotron"].extend(
[
"NemotronForCausalLM",
"NemotronForQuestionAnswering",
"NemotronForSequenceClassification",
"NemotronForTokenClassification",
"NemotronModel",
"NemotronPreTrainedModel",
]
)
_import_structure["models.nllb_moe"].extend(
[
"NllbMoeForConditionalGeneration",
@@ -5286,6 +5297,7 @@ if TYPE_CHECKING:
MusicgenMelodyDecoderConfig,
)
from .models.mvp import MvpConfig, MvpTokenizer
from .models.nemotron import NemotronConfig
from .models.nllb_moe import NllbMoeConfig
from .models.nougat import NougatProcessor
from .models.nystromformer import (
@@ -7187,6 +7199,14 @@ if TYPE_CHECKING:
MvpModel,
MvpPreTrainedModel,
)
from .models.nemotron import (
NemotronForCausalLM,
NemotronForQuestionAnswering,
NemotronForSequenceClassification,
NemotronForTokenClassification,
NemotronModel,
NemotronPreTrainedModel,
)
from .models.nllb_moe import (
NllbMoeForConditionalGeneration,
NllbMoeModel,
...
@@ -159,6 +159,7 @@ from . import (
musicgen,
musicgen_melody,
mvp,
nemotron,
nllb,
nllb_moe,
nougat,
...
@@ -177,6 +177,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("musicgen_melody", "MusicgenMelodyConfig"),
("mvp", "MvpConfig"),
("nat", "NatConfig"),
("nemotron", "NemotronConfig"),
("nezha", "NezhaConfig"),
("nllb-moe", "NllbMoeConfig"),
("nougat", "VisionEncoderDecoderConfig"),
@@ -469,6 +470,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("musicgen_melody", "MusicGen Melody"),
("mvp", "MVP"),
("nat", "NAT"),
("nemotron", "Nemotron"),
("nezha", "Nezha"),
("nllb", "NLLB"),
("nllb-moe", "NLLB-MOE"),
...
@@ -169,6 +169,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("musicgen_melody", "MusicgenMelodyModel"),
("mvp", "MvpModel"),
("nat", "NatModel"),
("nemotron", "NemotronModel"),
("nezha", "NezhaModel"),
("nllb-moe", "NllbMoeModel"),
("nystromformer", "NystromformerModel"),
@@ -481,6 +482,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("musicgen", "MusicgenForCausalLM"),
("musicgen_melody", "MusicgenMelodyForCausalLM"),
("mvp", "MvpForCausalLM"),
("nemotron", "NemotronForCausalLM"),
("olmo", "OlmoForCausalLM"),
("open-llama", "OpenLlamaForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
@@ -902,6 +904,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mra", "MraForSequenceClassification"),
("mt5", "MT5ForSequenceClassification"),
("mvp", "MvpForSequenceClassification"),
("nemotron", "NemotronForSequenceClassification"),
("nezha", "NezhaForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"),
("open-llama", "OpenLlamaForSequenceClassification"),
@@ -983,6 +986,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("mra", "MraForQuestionAnswering"),
("mt5", "MT5ForQuestionAnswering"),
("mvp", "MvpForQuestionAnswering"),
("nemotron", "NemotronForQuestionAnswering"),
("nezha", "NezhaForQuestionAnswering"),
("nystromformer", "NystromformerForQuestionAnswering"),
("opt", "OPTForQuestionAnswering"),
@@ -1078,6 +1082,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"),
("mt5", "MT5ForTokenClassification"),
("nemotron", "NemotronForTokenClassification"),
("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("persimmon", "PersimmonForTokenClassification"),
...
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_torch_available,
)
_import_structure = {
"configuration_nemotron": ["NemotronConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_nemotron"] = [
"NemotronForQuestionAnswering",
"NemotronForCausalLM",
"NemotronModel",
"NemotronPreTrainedModel",
"NemotronForSequenceClassification",
"NemotronForTokenClassification",
]
if TYPE_CHECKING:
from .configuration_nemotron import NemotronConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_nemotron import (
NemotronForCausalLM,
NemotronForQuestionAnswering,
NemotronForSequenceClassification,
NemotronForTokenClassification,
NemotronModel,
NemotronPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nemotron model configuration"""
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
logger = logging.get_logger(__name__)
class NemotronConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`NemotronModel`]. It is used to instantiate a Nemotron
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Nemotron-8B.
e.g. [nvidia/nemotron-3-8b-base-4k-hf](https://huggingface.co/nvidia/nemotron-3-8b-base-4k-hf).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Nemotron model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`NemotronModel`]
hidden_size (`int`, *optional*, defaults to 6144):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 48):
Number of attention heads for each attention layer in the Transformer decoder.
head_dim (`int`, *optional*):
Projection weights dimension in multi-head attention. Defaults to `hidden_size // num_attention_heads` if None.
num_key_value_heads (`int`, *optional*):
This is the number of key/value heads that should be used to implement Grouped Query Attention (GQA). If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group's key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.0134):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 3):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the input and output word embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Percentage of the query and keys which will have rotary embedding.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj and down_proj layers in the MLP layers.
```python
>>> from transformers import NemotronModel, NemotronConfig
>>> # Initializing a Nemotron nemotron-15b style configuration
>>> configuration = NemotronConfig()
>>> # Initializing a model from the nemotron-15b style configuration
>>> model = NemotronModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "nemotron"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=6144,
intermediate_size=24576,
num_hidden_layers=32,
num_attention_heads=48,
head_dim=None,
num_key_value_heads=None,
hidden_act="relu2",
max_position_embeddings=4096,
initializer_range=0.0134,
norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=2,
eos_token_id=3,
tie_word_embeddings=False,
rope_theta=10000.0,
partial_rotary_factor=0.5,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.partial_rotary_factor = partial_rotary_factor
rope_config_validation(self)
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
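As a quick sanity check on the defaults above, the derived attention dimensions can be worked out with plain arithmetic. This is an illustration of the `head_dim` default from `__init__` and of how a partial rotary factor is conventionally applied (only a fraction of each head gets RoPE), not a call into the library:

```python
# Mirrors the default head_dim logic in NemotronConfig.__init__ above.
hidden_size = 6144
num_attention_heads = 48
partial_rotary_factor = 0.5

head_dim = hidden_size // num_attention_heads       # per-head projection size
rotary_dim = int(head_dim * partial_rotary_factor)  # dims receiving RoPE per head

print(head_dim, rotary_dim)  # 128 64
```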
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
from argparse import ArgumentParser
from collections import OrderedDict
import torch
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging
from pytorch_lightning import Trainer
from transformers import LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import LlamaConverter
"""
Script to convert a nemotron checkpoint in nemo (mcore path) into a HuggingFace checkpoint.
This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.
1) Generate only HF weights from a nemo file:
python convert_nemotron_nemo_to_hf.py \
--input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
--output_path /path/to/pytorch_model.bin
2) Generate the full HF model folder
python convert_nemotron_nemo_to_hf.py \
--input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
--hf_input_path /path/to/input_hf_folder \
--hf_output_path /path/to/output_hf_folder \
Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Nemotron4 340b).
However, this option makes the conversion script significantly slower.
"""
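The heart of the conversion below is de-interleaving NeMo's fused QKV matrix into separate q/k/v projections. The index math it relies on can be illustrated standalone, in pure Python, using the GQA example from the comments in the conversion loop further down (64 query heads, 8 KV groups):

```python
# Standalone sketch of the fused-QKV row layout assumed by the conversion loop:
# per KV group, NeMo stores heads_per_group query heads, then 1 key, then 1 value.
head_num = 64          # attention (query) heads
num_query_groups = 8   # KV groups (GQA)

heads_per_group = head_num // num_query_groups   # 8 query heads per group
qkv_total_dim = head_num + 2 * num_query_groups  # 80 head-rows in the fused matrix

q_slice = [
    (heads_per_group + 2) * g + h
    for g in range(num_query_groups)
    for h in range(heads_per_group)
]
k_slice = list(range(heads_per_group, qkv_total_dim, heads_per_group + 2))
v_slice = list(range(heads_per_group + 1, qkv_total_dim, heads_per_group + 2))

print(q_slice[:10])  # [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
print(k_slice[:3], v_slice[:3])  # [8, 18, 28] [9, 19, 29]
```

The script itself computes the same indices with `torch.arange` and uses them to slice the reshaped `linear_qkv.weight` tensor.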
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--input_name_or_path",
type=str,
default=None,
required=True,
help="Path to .nemo file or extracted folder",
)
parser.add_argument("--output_path", type=str, default=None, required=False, help="Path to HF .bin file")
parser.add_argument(
"--hf_input_path",
type=str,
default=None,
help="A HF model path, " "e.g. a folder containing https://huggingface.co/nvidia/Minitron-8B-Base",
)
parser.add_argument(
"--hf_output_path",
type=str,
default=None,
help="Output HF model path, " "with the same format as above but user's own weights",
)
parser.add_argument(
"--precision",
type=str,
default=None,
help="Precision of output weights."
"Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
)
parser.add_argument(
"--cpu-only",
action="store_true",
help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
"but this option makes the conversion script significantly slower.",
)
args = parser.parse_args()
return args
def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, hf_url="nvidia/Minitron-8B-Base"):
"""
Convert NeMo config to HF config
"""
NEMO_ACT2HF = {
"squared-relu": "relu2",
"fast-swiglu": "silu",
}
DTYPE2HF = {
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
}
hf_config = {
"_name_or_path": hf_url,
"architectures": ["NemotronForCausalLM"],
"bos_token_id": tokenizer.bos_id,
"eos_token_id": tokenizer.eos_id,
"hidden_act": NEMO_ACT2HF[nemo_config.activation],
"hidden_size": nemo_config.hidden_size,
"initializer_range": nemo_config.init_method_std,
"intermediate_size": nemo_config.ffn_hidden_size,
"max_position_embeddings": nemo_config.max_position_embeddings,
"model_type": "nemotron",
"num_attention_heads": nemo_config.num_attention_heads,
"num_hidden_layers": nemo_config.num_layers,
"num_key_value_heads": nemo_config.get("num_query_groups", nemo_config.num_attention_heads),
"norm_eps": nemo_config.layernorm_epsilon,
"rope_theta": nemo_config.get("rotary_base", 10000),
"partial_rotary_factor": nemo_config.get("rotary_percentage", 1.0),
"tie_word_embeddings": False,
"torch_dtype": DTYPE2HF[dtype],
"transformers_version": "4.32.0.dev0", # TODO
"use_cache": True,
"vocab_size": vocab_size,
}
if nemo_config.kv_channels is not None:
hf_config["kv_channels"] = nemo_config.kv_channels
json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)
def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
"""
Convert NeMo weights to HF weights
"""
dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
model_config.tensor_model_parallel_size = 1
model_config.pipeline_model_parallel_size = 1
model_config.sequence_parallel = False
model_config.transformer_engine = True
if cpu_only:
map_location = torch.device("cpu")
model_config.use_cpu_initialization = True
model_config.dist_ckpt_load_on_device = False
else:
map_location = None
if cpu_only:
logging.info("******** Loading model on CPU. This will take a significant amount of time.")
model = MegatronGPTModel.restore_from(
input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
)
vocab_size = model.padded_vocab_size
if precision is None:
precision = model.cfg.precision
if precision in [32, "32"]:
dtype = torch.float32
elif precision in [16, "16", "16-mixed"]:
dtype = torch.float16
elif precision in ["bf16", "bf16-mixed"]:
dtype = torch.bfloat16
else:
logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
dtype = torch.float32 # fallback
logging.info(f"Using precision {dtype}")
def param_to_weights(param):
return param.to(dtype)
checkpoint = OrderedDict()
hidden_size = model.cfg.hidden_size
head_num = model.cfg.num_attention_heads
num_layers = model.cfg.num_layers
ffn_hidden_size = model.cfg.ffn_hidden_size
num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B
if num_query_groups is None:
num_query_groups = head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups
# Embedding
embed_weight = model.state_dict()["model.embedding.word_embeddings.weight"]
embed_weights_base_name = "model.embed_tokens.weight"
checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)
for l in range(int(num_layers)):
        logging.info(f"converting layer {l}")
qkv_weights = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.weight"]
qkv_weights = qkv_weights.reshape([qkv_total_dim, -1, hidden_size])
q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
## Example of slices
## (without GQA): num_query_groups = head_num = 32,
## q_slice = [0, 3, 6, 9 , ... 90, 93]
## k_slice = [1, 4, 7, 10, ... 91, 94]
## v_slice = [2, 5, 8, 11, ... 92, 95]
## (with GQA): num_query_groups = 8, head_num = 64
## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
## k_slice = [8, 18, 28, ... 68, 78]
## v_slice = [9, 19, 29, ... 69, 79]
q_weights_base_name = f"model.layers.{l}.self_attn.q_proj.weight"
k_weights_base_name = f"model.layers.{l}.self_attn.k_proj.weight"
v_weights_base_name = f"model.layers.{l}.self_attn.v_proj.weight"
checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
# attention dense
o_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_proj.weight"]
o_weight_base_name = f"model.layers.{l}.self_attn.o_proj.weight"
checkpoint[o_weight_base_name] = param_to_weights(o_weight)
        # MLP: NeMo's linear_fc1 maps to the HF up projection (or fused gate+up),
        # and linear_fc2 maps to the HF down projection
        fc1_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.weight"]
        fc2_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc2.weight"]
        if fc1_weight.shape[0] != fc2_weight.shape[1]:
            # Gated activation (e.g. SwiGLU): linear_fc1 fuses the gate and up
            # projections, so it has 2 * ffn_hidden_size rows
            logging.warning(
                "Gated projection layers detected in the NeMo checkpoint. Nemotron HF currently does not support gated MLPs."
            )
            assert fc1_weight.shape[0] == 2 * fc2_weight.shape[1]
            gate_proj_weight = fc1_weight[:ffn_hidden_size, :]
            up_proj_weight = fc1_weight[ffn_hidden_size:, :]
            checkpoint[f"model.layers.{l}.mlp.gate_proj.weight"] = param_to_weights(gate_proj_weight)
            checkpoint[f"model.layers.{l}.mlp.up_proj.weight"] = param_to_weights(up_proj_weight)
        else:
            checkpoint[f"model.layers.{l}.mlp.up_proj.weight"] = param_to_weights(fc1_weight)
        checkpoint[f"model.layers.{l}.mlp.down_proj.weight"] = param_to_weights(fc2_weight)
# layernorm
input_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight"]
input_ln_base_name = f"model.layers.{l}.input_layernorm.weight"
checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
if (
model.state_dict().get(f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias", None)
is not None
):
input_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias"]
input_ln_bias_name = f"model.layers.{l}.input_layernorm.bias"
checkpoint[input_ln_bias_name] = param_to_weights(input_ln_bias)
post_attn_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight"]
post_attn_ln_base_name = f"model.layers.{l}.post_attention_layernorm.weight"
checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
if model.state_dict().get(f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias", None) is not None:
post_attn_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias"]
post_attn_ln_bias_name = f"model.layers.{l}.post_attention_layernorm.bias"
checkpoint[post_attn_ln_bias_name] = param_to_weights(post_attn_ln_bias)
        logging.info(f"done layer {l}")
final_ln_weight = model.state_dict()["model.decoder.final_layernorm.weight"]
final_ln_base_name = "model.norm.weight"
checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
if model.state_dict().get("model.decoder.final_layernorm.bias", None) is not None:
final_ln_bias = model.state_dict()["model.decoder.final_layernorm.bias"]
final_ln_bias_name = "model.norm.bias"
checkpoint[final_ln_bias_name] = param_to_weights(final_ln_bias)
output_layer_weight = model.state_dict()["model.output_layer.weight"]
output_layer_base_name = "lm_head.weight"
checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)
os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
torch.save(checkpoint, output_hf_file)
logging.info(f"Weights saved to {output_hf_file}")
return model_config, model.tokenizer, dtype, vocab_size
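The GQA row-slice construction used in `convert()` above can be sketched as a standalone helper. This is a torch-free illustration of the index arithmetic only; `qkv_slices` is a hypothetical name, not part of the script:

```python
def qkv_slices(head_num, num_query_groups):
    """Row indices of the Q, K and V heads in Megatron's fused QKV weight.

    Each query group is laid out as [q_0 .. q_{h-1}, k, v], where
    h = head_num // num_query_groups, so the fused weight has
    head_num + 2 * num_query_groups head-sized row blocks in total.
    """
    heads_per_group = head_num // num_query_groups
    group = heads_per_group + 2  # rows per group: q-heads plus one k and one v
    total = head_num + 2 * num_query_groups
    q = [g * group + h for g in range(num_query_groups) for h in range(heads_per_group)]
    k = list(range(heads_per_group, total, group))
    v = list(range(heads_per_group + 1, total, group))
    return q, k, v
```

Without GQA (`head_num = num_query_groups = 32`) this reproduces the `[0, 3, 6, ...]` / `[1, 4, 7, ...]` / `[2, 5, 8, ...]` pattern noted in the comments; with GQA (`head_num = 64`, `num_query_groups = 8`) it reproduces `k = [8, 18, ..., 78]`.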
def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tokenizer):
tokenizer_cfg = model_config.tokenizer
if tokenizer_cfg.library == "sentencepiece":
        # For a sentencepiece tokenizer, wrap it in HF's LlamaTokenizer
        # and convert that to a PreTrainedTokenizerFast
        tokenizer_fn = tokenizer_cfg.model[5:]  # strip the 'nemo:' prefix
output_tokenizer = f"{output_hf_path}/tokenizer.model"
if nemo_file.endswith(".nemo"):
import tarfile
archive = tarfile.open(nemo_file, "r")
            tokenizer_filename = "./" + tokenizer_fn
archive.extract(tokenizer_filename, output_hf_path)
archive.close()
os.rename(f"{output_hf_path}/{tokenizer_fn}", output_tokenizer)
elif os.path.isdir(nemo_file):
shutil.copy(f"{nemo_file}/{tokenizer_fn}", output_tokenizer)
# We use LlamaTokenizer for sentencepiece based tokenizer
tokenizer = LlamaTokenizer.from_pretrained(output_hf_path, legacy=False)
# Convert the LlamaTokenizer to a PreTrainedTokenizerFast instance
tokenizer = PreTrainedTokenizerFast(
tokenizer_object=LlamaConverter(tokenizer).converted(), model_input_names=["input_ids", "token_type_ids"]
)
tokenizer.save_pretrained(output_hf_path)
        logging.info(f"Sentencepiece tokenizer has been saved to {output_tokenizer}")
elif isinstance(nemo_tokenizer, AutoTokenizer):
nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
else:
raise ValueError(f"Unsupported tokenizer type: library: {tokenizer_cfg.library}, type: {tokenizer_cfg.type}")
if __name__ == "__main__":
args = get_args()
if not args.hf_output_path:
assert args.output_path is not None, "Need to provide either output_path or hf_output_path"
else:
args.output_path = f"{args.hf_output_path}/pytorch_model.bin"
        logging.info(f"Weights will be saved to {args.output_path}")
nemo_config, nemo_tokenizer, dtype, vocab_size = convert(
args.input_name_or_path, args.output_path, precision=args.precision, cpu_only=args.cpu_only
)
if args.hf_input_path and args.hf_output_path:
convert_hf_config(nemo_config, nemo_tokenizer, vocab_size, dtype, args.hf_output_path, args.hf_input_path)
extract_nemotron_tokenizer(args.input_name_or_path, nemo_config, args.hf_output_path, nemo_tokenizer)
else:
logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.")
logging.info(f".bin file is saved to {args.output_path}")
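The precision handling in `convert()` reduces to a small mapping. A torch-free sketch of those branches (returning dtype name strings instead of `torch.dtype` objects; `precision_to_dtype_name` is a hypothetical helper, not part of the script):

```python
def precision_to_dtype_name(precision):
    """Mirror the precision branches in convert(): accept ints or the
    Lightning-style precision strings and fall back to float32."""
    if precision in (32, "32"):
        return "float32"
    if precision in (16, "16", "16-mixed"):
        return "float16"
    if precision in ("bf16", "bf16-mixed"):
        return "bfloat16"
    return "float32"  # unrecognized: the real script also logs a warning
```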
...@@ -6338,6 +6338,48 @@ class MvpPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])
class NemotronForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class NemotronPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
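The `Nemotron*` stubs above follow transformers' dummy-object pattern: each class is importable even when torch is absent, but instantiating it fails with a clear error. A simplified sketch of the idea (`BACKEND_AVAILABLE` and `NemotronModelDummy` are hypothetical names, not transformers internals):

```python
BACKEND_AVAILABLE = {"torch": False}  # pretend torch is not installed

def requires_backends(obj, backends):
    # Raise a clear error when any required backend is missing.
    missing = [b for b in backends if not BACKEND_AVAILABLE.get(b, False)]
    if missing:
        raise ImportError(f"{type(obj).__name__} requires the following backends: {missing}")

class NemotronModelDummy:
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```

Importing the class never fails; the error is deferred to first use, which keeps `from transformers import NemotronModel` safe in torch-less environments.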
class NllbMoeForConditionalGeneration(metaclass=DummyObject):
    _backends = ["torch"]
...
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Nemotron model."""
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import NemotronConfig, is_torch_available
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
from ...test_configuration_common import ConfigTester
if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
NemotronForCausalLM,
NemotronForQuestionAnswering,
NemotronForSequenceClassification,
NemotronForTokenClassification,
NemotronModel,
)
class NemotronModelTester(GemmaModelTester):
if is_torch_available():
config_class = NemotronConfig
model_class = NemotronModel
for_causal_lm_class = NemotronForCausalLM
for_sequence_class = NemotronForSequenceClassification
for_token_class = NemotronForTokenClassification
@require_torch
class NemotronModelTest(GemmaModelTest):
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
all_model_classes = (
(
NemotronModel,
NemotronForCausalLM,
NemotronForSequenceClassification,
NemotronForQuestionAnswering,
NemotronForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (NemotronForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": NemotronModel,
"text-classification": NemotronForSequenceClassification,
"text-generation": NemotronForCausalLM,
"zero-shot": NemotronForSequenceClassification,
"question-answering": NemotronForQuestionAnswering,
"token-classification": NemotronForTokenClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
# used in `test_torch_compile`
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
def setUp(self):
self.model_tester = NemotronModelTester(self)
self.config_tester = ConfigTester(self, config_class=NemotronConfig, hidden_size=37)
@require_torch_sdpa
@slow
@unittest.skip(
        reason="Due to the custom causal mask, the difference between eager and SDPA outputs in bfloat16 is slightly too large."
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
pass
@unittest.skip("Eager and SDPA do not produce the same outputs, thus this test fails")
def test_model_outputs_equivalence(self, **kwargs):
pass
@require_torch_sdpa
@require_torch_gpu
@slow
def test_sdpa_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_sdpa:
self.skipTest(reason="Model does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
)
model_sdpa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
# nemotron sdpa needs a high tolerance
assert torch.allclose(logits_sdpa, logits, atol=1e-2)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
# nemotron flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=1e-2)
@require_torch_gpu
class NemotronIntegrationTest(unittest.TestCase):
    # This variable is used to determine which CUDA device we are using for our runners (A10 or T4).
    # Depending on the hardware, we get different logits / generations.
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@slow
@require_read_token
def test_nemotron_8b_generation_sdpa(self):
text = ["What is the largest planet in solar system?"]
EXPECTED_TEXT = [
"What is the largest planet in solar system?\nAnswer: Jupiter\n\nWhat is the answer",
]
model_id = "thhaus/nemotron3-8b"
model = NemotronForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
@slow
@require_read_token
def test_nemotron_8b_generation_eager(self):
text = ["What is the largest planet in solar system?"]
EXPECTED_TEXT = [
"What is the largest planet in solar system?\nAnswer: Jupiter\n\nWhat is the answer",
]
model_id = "thhaus/nemotron3-8b"
model = NemotronForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
@slow
@require_read_token
def test_nemotron_8b_generation_fa2(self):
text = ["What is the largest planet in solar system?"]
EXPECTED_TEXT = [
"What is the largest planet in solar system?\nAnswer: Jupiter\n\nWhat is the answer",
]
model_id = "thhaus/nemotron3-8b"
model = NemotronForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)