Unverified Commit b4d4d6fe authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add RWKV-4 (#22797)



* First draft of RWKV-4

* Add support for generate

* Style post-rebase

* Properly use state

* Write doc

* Fix doc

* More math

* Add model to README, dummies and clean config

* Fix init

* multiple fixes:

- fix common tests
- fix configuraion default values
- add CI test for checking state computation
- fix some CI tests

* correct tokenizer

* some tweaks

- fix config docstring
- fix failing tests

* fix CI tests

- add output_attention / output_hidden_states
- override test_initialization
- fix failing CIs

* fix conversion script

- fix sharded case
- add new arguments

* add slow tests + more fixes on conversion script

* add another test

* final fixes

* change single name variable

* add mock attention mask for pipeline to work

* correct eos token id

* fix nits

* add checkpoints

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add `tie_word_embeddings` in docstring

* change tensor name

* fix final nits

* Trigger CI

---------
Co-authored-by: default avataryounesbelkada <younesbelkada@gmail.com>
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9a50cb61
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig", "RwkvOnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_rwkv"] = [
"RWKV_PRETRAINED_MODEL_ARCHIVE_LIST",
"RwkvForCausalLM",
"RwkvModel",
"RwkvPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig, RwkvOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_rwkv import (
RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
RwkvForCausalLM,
RwkvModel,
RwkvPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RWKV configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"RWKV/rwkv-4-169m-pile": "https://huggingface.co/RWKV/rwkv-4-169m-pile/resolve/main/config.json",
"RWKV/rwkv-4-430m-pile": "https://huggingface.co/RWKV/rwkv-4-430m-pile/resolve/main/config.json",
"RWKV/rwkv-4-1b5-pile": "https://huggingface.co/RWKV/rwkv-4-1b5-pile/resolve/main/config.json",
"RWKV/rwkv-4-3b-pile": "https://huggingface.co/RWKV/rwkv-4-3b-pile/resolve/main/config.json",
"RWKV/rwkv-4-7b-pile": "https://huggingface.co/RWKV/rwkv-4-7b-pile/resolve/main/config.json",
"RWKV/rwkv-4-14b-pile": "https://huggingface.co/RWKV/rwkv-4-14b-pile/resolve/main/config.json",
"RWKV/rwkv-raven-1b5": "https://huggingface.co/RWKV/rwkv-raven-1b5/resolve/main/config.json",
"RWKV/rwkv-raven-3b": "https://huggingface.co/RWKV/rwkv-raven-3b/resolve/main/config.json",
"RWKV/rwkv-raven-7b": "https://huggingface.co/RWKV/rwkv-raven-7b/resolve/main/config.json",
"RWKV/rwkv-raven-14b": "https://huggingface.co/RWKV/rwkv-raven-14b/resolve/main/config.json",
}
class RwkvConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the RWVK-4
[RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50277):
Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`RwkvModel`].
context_length (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model can be be used with in a single forward (using it in RNN mode
lets use any sequence length).
hidden_size (`int`, *optional*, defaults to 4096):
Dimensionality of the embeddings and hidden states.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the model.
attention_hidden_size (`int`, *optional*):
Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
intermediate_size (`int`, *optional*):
Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon to use in the layer normalization layers.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
as GPTNeoX.
eos_token_id (`int`, *optional*, defaults to 0):
The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
GPTNeoX.
rescale_every (`int`, *optional*, default to 6):
At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every
`rescale_every` layer. If set to 0 or a negative number, no rescale is done.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether or not to tie the word embeddings with the input token embeddings.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last state.
Example:
```python
>>> from transformers import RwkvConfig, RwkvModel
>>> # Initializing a Rwkv configuration
>>> configuration = RwkvConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = RwkvModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "rwkv"
attribute_map = {"max_position_embeddings": "context_length"}
def __init__(
self,
vocab_size=50277,
context_length=1024,
hidden_size=4096,
num_hidden_layers=32,
attention_hidden_size=None,
intermediate_size=None,
layer_norm_epsilon=1e-5,
bos_token_id=0,
eos_token_id=0,
rescale_every=6,
tie_word_embeddings=False,
use_cache=True,
**kwargs,
):
self.vocab_size = vocab_size
self.context_length = context_length
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size
self.layer_norm_epsilon = layer_norm_epsilon
self.rescale_every = rescale_every
self.use_cache = use_cache
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
super().__init__(
tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a RWKV checkpoint from BlinkDL to the Hugging Face format."""
import argparse
import gc
import json
import os
import re
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, RwkvConfig
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint
NUM_HIDDEN_LAYERS_MAPPING = {
"169M": 12,
"430M": 24,
"1B5": 24,
"3B": 32,
"7B": 32,
"14B": 40,
}
HIDEN_SIZE_MAPPING = {
"169M": 768,
"430M": 1024,
"1B5": 2048,
"3B": 2560,
"7B": 4096,
"14B": 5120,
}
def convert_state_dict(state_dict):
state_dict_keys = list(state_dict.keys())
for name in state_dict_keys:
weight = state_dict.pop(name)
# emb -> embedding
if name.startswith("emb."):
name = name.replace("emb.", "embeddings.")
# ln_0 -> pre_ln (only present at block 0)
if name.startswith("blocks.0.ln0"):
name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
# att -> attention
name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
# ffn -> feed_forward
name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
# time_mix_k -> time_mix_key and reshape
if name.endswith(".time_mix_k"):
name = name.replace(".time_mix_k", ".time_mix_key")
# time_mix_v -> time_mix_value and reshape
if name.endswith(".time_mix_v"):
name = name.replace(".time_mix_v", ".time_mix_value")
# time_mix_r -> time_mix_key and reshape
if name.endswith(".time_mix_r"):
name = name.replace(".time_mix_r", ".time_mix_receptance")
if name != "head.weight":
name = "rwkv." + name
state_dict[name] = weight
return state_dict
def convert_rmkv_checkpoint_to_hf_format(
repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None
):
# 1. If possible, build the tokenizer.
if tokenizer_file is None:
print("No `--tokenizer_file` provided, we will use the default tokenizer.")
vocab_size = 50277
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
else:
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
vocab_size = len(tokenizer)
tokenizer.save_pretrained(output_dir)
# 2. Build the config
possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys())
if size is None:
# Try to infer size from the checkpoint name
for candidate in possible_sizes:
if candidate in checkpoint_file:
size = candidate
break
if size is None:
raise ValueError("Could not infer the size, please provide it with the `--size` argument.")
if size not in possible_sizes:
raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.")
config = RwkvConfig(
vocab_size=vocab_size,
num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size],
hidden_size=HIDEN_SIZE_MAPPING[size],
)
config.save_pretrained(output_dir)
# 3. Download model file then convert state_dict
model_file = hf_hub_download(repo_id, checkpoint_file)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = convert_state_dict(state_dict)
# 4. Split in shards and save
shards, index = shard_checkpoint(state_dict)
for shard_file, shard in shards.items():
torch.save(shard, os.path.join(output_dir, shard_file))
if index is not None:
save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
# 5. Clean up shards (for some reason the file PyTorch saves take the same space as the whole state_dict
print(
"Cleaning up shards. This may error with an OOM error, it this is the case don't worry you still have converted the model."
)
shard_files = list(shards.keys())
del state_dict
del shards
gc.collect()
for shard_file in shard_files:
state_dict = torch.load(os.path.join(output_dir, shard_file))
torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file))
del state_dict
gc.collect()
if push_to_hub:
if model_name is None:
raise ValueError("Please provide a `model_name` to push the model to the Hub.")
model = AutoModelForCausalLM.from_pretrained(output_dir)
model.push_to_hub(model_name, max_shard_size="2GB")
tokenizer.push_to_hub(model_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint."
)
parser.add_argument(
"--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo."
)
parser.add_argument(
"--output_dir", default=None, type=str, required=True, help="Where to save the converted model."
)
parser.add_argument(
"--tokenizer_file",
default=None,
type=str,
help="Path to the tokenizer file to use (if not provided, only the model is converted).",
)
parser.add_argument(
"--size",
default=None,
type=str,
help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Push to the Hub the converted model.",
)
parser.add_argument(
"--model_name",
default=None,
type=str,
help="Name of the pushed model on the Hub, including the username / organization.",
)
args = parser.parse_args()
convert_rmkv_checkpoint_to_hf_format(
args.repo_id,
args.checkpoint_file,
args.output_dir,
size=args.size,
tokenizer_file=args.tokenizer_file,
push_to_hub=args.push_to_hub,
model_name=args.model_name,
)
This diff is collapsed.
...@@ -6105,6 +6105,30 @@ def load_tf_weights_in_roformer(*args, **kwargs): ...@@ -6105,6 +6105,30 @@ def load_tf_weights_in_roformer(*args, **kwargs):
requires_backends(load_tf_weights_in_roformer, ["torch"]) requires_backends(load_tf_weights_in_roformer, ["torch"])
RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = None
class RwkvForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RwkvModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RwkvPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.util import safe_repr
from transformers import AutoTokenizer, RwkvConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
RwkvForCausalLM,
RwkvModel,
)
class RwkvModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_token_type_ids=False,
use_input_mask=True,
use_labels=True,
use_mc_token_ids=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_token_type_ids = use_token_type_ids
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.use_mc_token_ids = use_mc_token_ids
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def get_large_model_config(self):
return RwkvConfig.from_pretrained("sgugger/rwkv-4-pile-7b")
def prepare_config_and_inputs(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
mc_token_ids = None
if self.use_mc_token_ids:
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(
gradient_checkpointing=gradient_checkpointing,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
return (
config,
input_ids,
input_mask,
None,
token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
)
def get_config(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
return RwkvConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=self.intermediate_size,
activation_function=self.hidden_act,
resid_pdrop=self.hidden_dropout_prob,
attn_pdrop=self.attention_probs_dropout_prob,
n_positions=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
def get_pipeline_config(self):
config = self.get_config()
config.vocab_size = 300
return config
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
config.output_hidden_states = True
model = RwkvModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)
def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = RwkvForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = RwkvModel(config=config)
model.to(torch_device)
model.eval()
outputs = model(input_ids)
output_whole = outputs.last_hidden_state
outputs = model(input_ids[:, :2])
output_one = outputs.last_hidden_state
# Using the state computed on the first inputs, we will get the same output
outputs = model(input_ids[:, 2:], state=outputs.state)
output_two = outputs.last_hidden_state
self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = RwkvForCausalLM(config)
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
result.loss.backward()
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
input_mask,
head_mask,
token_type_ids,
mc_token_ids,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
@require_torch
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RwkvModel,
"text-generation": RwkvForCausalLM,
}
if is_torch_available()
else {}
)
# all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
fx_compatible = False
test_missing_keys = False
test_model_parallel = False
test_pruning = False
test_head_masking = False # Rwkv does not support head masking
def setUp(self):
self.model_tester = RwkvModelTester(self)
self.config_tester = ConfigTester(
self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)
def assertInterval(self, member, container, msg=None):
r"""
Simple utility function to check if a member is inside an interval.
"""
if isinstance(member, torch.Tensor):
max_value, min_value = member.max().item(), member.min().item()
elif isinstance(member, list) or isinstance(member, tuple):
max_value, min_value = max(member), min(member)
if not isinstance(container, list):
raise TypeError("container should be a list or tuple")
elif len(container) != 2:
raise ValueError("container should have 2 elements")
expected_min, expected_max = container
is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max)
if not is_inside_interval:
standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container))
self.fail(self._formatMessage(msg, standardMsg))
def test_config(self):
self.config_tester.run_common_tests()
def test_rwkv_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_rwkv_model(*config_and_inputs)
def test_rwkv_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causl_lm(*config_and_inputs)
def test_state_equivalency(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_state_equivalency(*config_and_inputs)
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config=config)
for name, param in model.named_parameters():
if "time_decay" in name:
if param.requires_grad:
self.assertTrue(param.data.max().item() == 3.0)
self.assertTrue(param.data.min().item() == -5.0)
elif "time_first" in name:
if param.requires_grad:
# check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
elif any([x in name for x in ["time_mix_key", "time_mix_receptance"]]):
if param.requires_grad:
self.assertInterval(
param.data,
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif "time_mix_value" in name:
if param.requires_grad:
self.assertInterval(
param.data,
[0.0, 1.3],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_attention_outputs(self):
r"""
Overriding the test_attention_outputs test as the attention outputs of Rwkv are different from other models
it has a shape `batch_size, seq_len, hidden_size`.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
batch_size = inputs["input_ids"].shape[0]
with torch.no_grad():
outputs = model(**inputs)
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
batch_size = inputs["input_ids"].shape[0]
with torch.no_grad():
outputs = model(**inputs)
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[batch_size, seq_len, config.hidden_size],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
batch_size = inputs["input_ids"].shape[0]
with torch.no_grad():
outputs = model(**inputs)
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[batch_size, seq_len, config.hidden_size],
)
@slow
def test_model_from_pretrained(self):
for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = RwkvModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@slow
class RWKVIntegrationTests(unittest.TestCase):
def setUp(self):
self.model_id = "RWKV/rwkv-4-169m-pile"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
def test_simple_generate(self):
expected_output = "Hello my name is Jasmine and I am a newbie to the"
model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device)
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
output = model.generate(input_ids, max_new_tokens=10)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
def test_simple_generate_bf16(self):
expected_output = "Hello my name is Jasmine and I am a newbie to the"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
model = RwkvForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
output = model.generate(input_ids, max_new_tokens=10)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
...@@ -93,15 +93,20 @@ config_common_kwargs = { ...@@ -93,15 +93,20 @@ config_common_kwargs = {
class ConfigTester(object): class ConfigTester(object):
def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs): def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
self.parent = parent self.parent = parent
self.config_class = config_class self.config_class = config_class
self.has_text_modality = has_text_modality self.has_text_modality = has_text_modality
self.inputs_dict = kwargs self.inputs_dict = kwargs
self.common_properties = common_properties
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
config = self.config_class(**self.inputs_dict) config = self.config_class(**self.inputs_dict)
common_properties = ["hidden_size", "num_attention_heads", "num_hidden_layers"] common_properties = (
["hidden_size", "num_attention_heads", "num_hidden_layers"]
if self.common_properties is None
else self.common_properties
)
# Add common fields for text models # Add common fields for text models
if self.has_text_modality: if self.has_text_modality:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment