Unverified Commit e1c3ac25 authored by Susnato Dhar's avatar Susnato Dhar Committed by GitHub
Browse files

Add Phi-1 and Phi-1_5 (#26170)

* only dir not even init

* init

* tokenizer removed and reference of codegen added

* modeling file updated a lot remaining app_rotary_emb

* conversion script done

* conversion script fixed, a lot of factoring done and most tests pass

* added token_clf and extractive_QA_head

* integration tests pass

* flash attn tests pass!

* config done

* more docs in modeling file

* some style fix

* style and others

* doc test error fix

* more doc fix

* some attention fixes

* most fixes

* style and other fixes

* docs fix and config

* doc fix

* some comments

* conversion script updated

* conversion script updated

* Revert "conversion script updated"

This reverts commit e92378c54084ec0747041b113083d1746ecb6c7f.

* final comments

* add Phi to language_modeling.md

* edit phi.md file

* rebase and fix

* removed phi-1.5 example

* changed model_type from 'phi'->'mixformer-sequential'

* small change

* small change

* revert \small change

* changed mixformer-sequential->phi

* small change

* added phi-1.5 example instead of phi-1

* doc test might pass now

* rebase and small change

* added the dropout layer

* more fixes

* modified .md file

* very very small doc change
parent 00dc8562
# coding=utf-8
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Phi model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"susnato/phi-1_dev": "https://huggingface.co/susnato/phi-1_dev/resolve/main/config.json",
"susnato/phi-1_5_dev": "https://huggingface.co/susnato/phi-1_5_dev/resolve/main/config.json",
}
class PhiConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Phi
[susnato/phi-1_dev](https://huggingface.co/susnato/phi-1_dev).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 51200):
Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`PhiModel`].
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
resid_pdrop (`float`, *optional*, defaults to 0.0):
Dropout probability for mlp outputs.
embd_pdrop (`int`, *optional*, defaults to 0.0):
The dropout ratio for the embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio after computing the attention scores.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
is an experimental feature, subject to breaking API changes in future versions.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Percentage of the query and keys which will have rotary embedding.
qk_layernorm (`bool`, *optional*, defaults to `False`):
Whether or not to normalize the Queries and Keys after projecting the hidden states
bos_token_id (`int`, *optional*, defaults to 1):
Denotes beginning of sequences token id.
eos_token_id (`int`, *optional*, defaults to 2):
Denotes end of sequences token id.
Example:
```python
>>> from transformers import PhiModel, PhiConfig
>>> # Initializing a Phi-1 style configuration
>>> configuration = PhiConfig.from_pretrained("susnato/phi-1_dev")
>>> # Initializing a model from the configuration
>>> model = PhiModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "phi"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=51200,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
num_attention_heads=32,
resid_pdrop=0.0,
embd_pdrop=0.0,
attention_dropout=0.0,
hidden_act="gelu_new",
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.5,
qk_layernorm=False,
bos_token_id=1,
eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self.qk_layernorm = qk_layernorm
self._rope_scaling_validation()
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
# coding=utf-8
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Weights conversion script for Phi
This script downloads both Phi-1 and Phi-1.5 checkpoints to "checkpoint_path" and then converts the weights to
HugfgingFace model's format and saves them in "pytorch_dump_folder_path".
"""
import argparse
import gc
import os
import torch
from huggingface_hub import hf_hub_download
from transformers import PhiConfig, PhiForCausalLM
_MODELS = {
"microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/blob/main/pytorch_model.bin",
"microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/blob/main/pytorch_model.bin",
}
PHI_MAPPING = {
"layers.0.wte.weight": "model.embed_tokens.weight",
"layers.25.linear.bias": "lm_head.bias",
"layers.25.linear.weight": "lm_head.weight",
"layers.25.ln.bias": "model.final_layernorm.bias",
"layers.25.ln.weight": "model.final_layernorm.weight",
"layers": "model.layers",
"ln": "input_layernorm",
"mixer": "self_attn",
"Wqkv": "query_key_value",
"out_proj": "dense",
}
def convert_weights(original_weights, mapping, config):
converted_weights = {}
original_weights_keys = sorted(original_weights.keys())
# we change names (1-24) -> layers(0-23) for Phi model layers
range_change = {
f"layers.{k}.": f"layers.{v}."
for k, v in zip(range(1, config.num_hidden_layers + 1), range(0, config.num_hidden_layers))
}
mapping.update(**range_change)
for original_weights_key in original_weights_keys:
new_key = original_weights_key
if "rotary_emb" in new_key:
continue
if "Wqkv" in new_key:
if "weight" in new_key:
weight = original_weights[new_key]
weights_shape = weight.shape
weight = (
weight.view(3, config.num_attention_heads, -1, config.hidden_size)
.transpose(0, 1)
.reshape(*weights_shape)
)
original_weights[new_key] = weight
elif "bias" in new_key:
bias = original_weights[new_key]
bias_shape = bias.shape
bias = bias.view(3, config.num_attention_heads, -1).transpose(0, 1).reshape(*bias_shape)
original_weights[new_key] = bias
for k, v in mapping.items():
if k in new_key:
new_key = new_key.replace(k, v)
converted_weights[new_key] = original_weights.pop(original_weights_key)
return converted_weights
def _download(url: str, root: str):
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
filename = f"{url.split('/')[-1]}"
hf_hub_download(
repo_id=repo_id,
filename=filename,
force_filename=root,
local_dir_use_symlinks=False,
)
def convert_phi_weights(checkpoint_path, pytorch_dump_folder_path, use_cuda, save_weights_directly):
device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
for each_model_name, each_model_url in _MODELS.items():
converted_checkpoint = {}
model_path = os.path.join(checkpoint_path, each_model_name + "_" + each_model_url.split("/")[-1])
if not os.path.exists(model_path):
print(f"\n{each_model_name} was not found! Downloading it to {model_path}")
_download(url=each_model_url, root=model_path)
model_checkpoint = torch.load(model_path, map_location=device)
model_type = each_model_name.split("/")[1] # phi-1 or phi-1_5
config = PhiConfig.from_pretrained(f"susnato/{model_type}_dev")
# Converting the weights
converted_checkpoint.update(**convert_weights(model_checkpoint, PHI_MAPPING, config))
# Save either the whole model or the converted weights
if save_weights_directly:
save_weights_path = os.path.join(
pytorch_dump_folder_path, each_model_name.split("/")[-1] + "_" + each_model_url.split("/")[-1]
)
torch.save(converted_checkpoint, save_weights_path)
print(f"Model weights saved at {save_weights_path}!")
else:
model = PhiForCausalLM(config).to(device)
model.load_state_dict(converted_checkpoint, strict=True)
save_model_path = os.path.join(pytorch_dump_folder_path, model_type)
model.save_pretrained(save_model_path)
print(f"Model saved at {save_model_path}!")
# release GPU memory for the 2nd model if cuda was used.
del config, model
# release GPU memory for the 2nd model if cuda was used.
del model_checkpoint, converted_checkpoint
if use_cuda:
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# # Required parameters
parser.add_argument(
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model. (Please enter full path)",
)
parser.add_argument(
"--use_cuda",
default=False,
type=bool,
help="Whether to load the weights on GPU during conversion or not, False by default",
)
parser.add_argument(
"--save_weights_directly",
default=True,
type=bool,
help="Whether to save the weights directly after conversion or load the weight to the Phi model and then save "
"the Phi model along with weights. True by default",
)
args = parser.parse_args()
convert_phi_weights(args.checkpoint_path, args.pytorch_dump_folder_path, args.use_cuda, args.save_weights_directly)
This diff is collapsed.
......@@ -6172,6 +6172,44 @@ class PersimmonPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
PHI_PRETRAINED_MODEL_ARCHIVE_LIST = None
class PhiForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhiForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhiForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhiModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhiPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Phi model. """
import unittest
from transformers import PhiConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
PhiForCausalLM,
PhiForSequenceClassification,
PhiForTokenClassification,
PhiModel,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Phi
class PhiModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return PhiConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = PhiModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = PhiModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = PhiForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = PhiForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(PhiModel, PhiForCausalLM, PhiForSequenceClassification, PhiForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (PhiForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": PhiModel,
"text-classification": PhiForSequenceClassification,
"text-generation": PhiForCausalLM,
"token-classification": PhiForTokenClassification,
"zero-shot": PhiForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phi
def setUp(self):
self.model_tester = PhiModelTester(self)
self.config_tester = ConfigTester(self, config_class=PhiConfig, hidden_size=37)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phi,llama->phi
def test_phi_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = PhiForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phi,llama->phi
def test_phi_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "single_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = PhiForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phi,llama->phi
def test_phi_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "multi_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = PhiForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@slow
@require_torch
class PhiIntegrationTest(unittest.TestCase):
def test_model_phi_1_logits(self):
input_ids = {
"input_ids": torch.tensor(
[[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
)
}
model = PhiForCausalLM.from_pretrained("susnato/phi-1_dev").to(torch_device)
model.eval()
output = model(**input_ids).logits
# fmt: off
EXPECTED_OUTPUT = torch.tensor([[2.2671, 6.7684, -2.0107, -1.2440, -1.5335, -2.3828, 6.9186, 6.4245, 3.1548, 0.9998, 0.0760, 4.4653, 4.9857, 4.2956, 1.2308, -1.4178, 0.1361, 0.5191, -0.5699, -2.2201, -3.0750, -3.9600, -4.5936, -3.7394, -2.7777, 6.1874, -0.4148, -1.5684, -0.5967, 0.2395], [1.7004, 4.0383, 0.0546, 0.4530, -0.3619, -0.9021, 1.8355, 1.3587, 1.2406, 2.5775, -0.8834, 5.1910, 4.2565, 4.1406, 3.0752, -0.9099, 1.1595, 0.0264, 0.3243, -1.1803, -1.3945, -2.1406, -3.9939, -1.4438, -2.9546, 3.9204, 1.0851, -1.0598, -1.7819, -0.4827]]).to(torch_device)
# fmt: on
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-4, rtol=1e-4))
def test_model_phi_1_5_logits(self):
input_ids = {
"input_ids": torch.tensor(
[[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
)
}
model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev").to(torch_device)
model.eval()
output = model(**input_ids).logits
# fmt: off
EXPECTED_OUTPUT = torch.tensor([[12.2922, 13.3507, 8.6963, 9.1355, 9.3502, 9.2667, 14.2027, 13.1363, 13.5446, 11.1337, 9.9279, 16.7195, 13.0768, 14.9141, 11.9965, 8.0233, 10.3129, 10.6118, 10.0204, 9.3827, 8.8344, 8.2806, 8.0153, 8.0540, 7.0964, 16.5743, 11.1256, 9.6987, 11.4770, 10.5440], [12.3323, 14.6050, 8.9986, 8.1580, 9.5654, 6.6728, 12.5966, 12.6662, 12.2784, 11.7522, 8.2039, 16.3102, 11.2203, 13.6088, 12.0125, 9.1021, 9.8216, 10.0987, 9.0926, 8.4260, 8.8009, 7.6547, 6.8075, 7.7881, 7.4501, 15.7451, 10.5053, 8.3129, 10.0027, 9.2612]]).to(torch_device)
# fmt: on
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-4, rtol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment