Unverified commit 079bf3cb, authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig): ...@@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
``` ```"""
"""
model_type = "chatglm" model_type = "chatglm"
def __init__( def __init__(
self, self,
vocab_size=130528, vocab_size=130528,
hidden_size=4096, hidden_size=4096,
num_layers=28, num_layers=28,
num_attention_heads=32, num_attention_heads=32,
layernorm_epsilon=1e-5, layernorm_epsilon=1e-5,
use_cache=True, use_cache=True,
bos_token_id=130004, bos_token_id=130004,
eos_token_id=130005, eos_token_id=130005,
mask_token_id=130000, mask_token_id=130000,
gmask_token_id=130001, gmask_token_id=130001,
pad_token_id=3, pad_token_id=3,
max_sequence_length=2048, max_sequence_length=2048,
inner_hidden_size=16384, inner_hidden_size=16384,
position_encoding_2d=True, position_encoding_2d=True,
quantization_bit=0, quantization_bit=0,
pre_seq_len=None, pre_seq_len=None,
prefix_projection=False, prefix_projection=False,
**kwargs **kwargs,
): ):
self.num_layers = num_layers self.num_layers = num_layers
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig): ...@@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig):
self.pre_seq_len = pre_seq_len self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection self.prefix_projection = prefix_projection
super().__init__( super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
\ No newline at end of file
...@@ -4,41 +4,40 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/mo ...@@ -4,41 +4,40 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/mo
""" PyTorch ChatGLM model. """ """ PyTorch ChatGLM model. """
import math
import copy import copy
import math
import os import os
import warnings
import re import re
import sys import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithPast,
) )
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging from transformers.utils import (
from transformers.generation.logits_process import LogitsProcessor add_code_sample_docstrings,
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_chatglm import ChatGLMConfig from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
if sys.platform != 'darwin': if sys.platform != "darwin":
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_cpu(True)
...@@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): ...@@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model # which are not required for using pretrained model
if any( if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name for n in name
): ):
logger.info(f"Skipping {'/'.join(name)}") logger.info(f"Skipping {'/'.join(name)}")
continue continue
...@@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): ...@@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
array = np.transpose(array) array = np.transpose(array)
try: try:
assert ( assert (
pointer.shape == array.shape pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
...@@ -153,7 +152,7 @@ class PrefixEncoder(torch.nn.Module): ...@@ -153,7 +152,7 @@ class PrefixEncoder(torch.nn.Module):
self.trans = torch.nn.Sequential( self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(), torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
) )
else: else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
...@@ -170,8 +169,7 @@ class PrefixEncoder(torch.nn.Module): ...@@ -170,8 +169,7 @@ class PrefixEncoder(torch.nn.Module):
@torch.jit.script @torch.jit.script
def gelu_impl(x): def gelu_impl(x):
"""OpenAI's gelu implementation.""" """OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
(1.0 + 0.044715 * x * x)))
def gelu(x): def gelu(x):
...@@ -181,21 +179,22 @@ def gelu(x): ...@@ -181,21 +179,22 @@ def gelu(x):
class RotaryEmbedding(torch.nn.Module): class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False): def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__() super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half() inv_freq = inv_freq.half()
self.learnable = learnable self.learnable = learnable
if learnable: if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq) self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None self.max_seq_len_cached = None
else: else:
self.register_buffer('inv_freq', inv_freq) self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = None self.max_seq_len_cached = None
self.cos_cached = None self.cos_cached = None
self.sin_cached = None self.sin_cached = None
self.precision = precision self.precision = precision
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, def _load_from_state_dict(
error_msgs): self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
pass pass
def forward(self, x, seq_dim=1, seq_len=None): def forward(self, x, seq_dim=1, seq_len=None):
...@@ -204,7 +203,7 @@ class RotaryEmbedding(torch.nn.Module): ...@@ -204,7 +203,7 @@ class RotaryEmbedding(torch.nn.Module):
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation # Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device) emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16: if self.precision == torch.bfloat16:
...@@ -230,30 +229,31 @@ class RotaryEmbedding(torch.nn.Module): ...@@ -230,30 +229,31 @@ class RotaryEmbedding(torch.nn.Module):
def rotate_half(x): def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
@torch.jit.script @torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) position_id, sin.squeeze(1)
).unsqueeze(2)
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k return q, k
def attention_fn( def attention_fn(
self, self,
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
attention_mask, attention_mask,
hidden_size_per_partition, hidden_size_per_partition,
layer_id, layer_id,
layer_past=None, layer_past=None,
scaling_attention_score=True, scaling_attention_score=True,
use_cache=False, use_cache=False,
): ):
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1] past_key, past_value = layer_past[0], layer_past[1]
...@@ -285,7 +285,9 @@ def attention_fn( ...@@ -285,7 +285,9 @@ def attention_fn(
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
matmul_result = torch.zeros( matmul_result = torch.zeros(
1, 1, 1, 1,
1,
1,
dtype=query_layer.dtype, dtype=query_layer.dtype,
device=query_layer.device, device=query_layer.device,
) )
...@@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs): ...@@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):
class SelfAttention(torch.nn.Module): class SelfAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads, def __init__(
layer_id, hidden_size_per_attention_head=None, bias=True, self,
params_dtype=torch.float, position_encoding_2d=True, empty_init=True): hidden_size,
num_attention_heads,
layer_id,
hidden_size_per_attention_head=None,
bias=True,
params_dtype=torch.float,
position_encoding_2d=True,
empty_init=True,
):
if empty_init: if empty_init:
init_method = skip_init init_method = skip_init
else: else:
...@@ -410,8 +420,7 @@ class SelfAttention(torch.nn.Module): ...@@ -410,8 +420,7 @@ class SelfAttention(torch.nn.Module):
attention_scores.masked_fill_(attention_mask, -10000.0) attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores return attention_scores
def split_tensor_along_last_dim(self, tensor, num_partitions, def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
contiguous_split_chunks=False):
"""Split a tensor along its last dimension. """Split a tensor along its last dimension.
Arguments: Arguments:
tensor: input tensor. tensor: input tensor.
...@@ -431,14 +440,14 @@ class SelfAttention(torch.nn.Module): ...@@ -431,14 +440,14 @@ class SelfAttention(torch.nn.Module):
return tensor_list return tensor_list
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids, position_ids,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
layer_id, layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
): ):
""" """
hidden_states: [seq_len, batch, hidden_size] hidden_states: [seq_len, batch, hidden_size]
...@@ -462,8 +471,10 @@ class SelfAttention(torch.nn.Module): ...@@ -462,8 +471,10 @@ class SelfAttention(torch.nn.Module):
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ position_ids, block_position_ids = (
position_ids[:, 1, :].transpose(0, 1).contiguous() position_ids[:, 0, :].transpose(0, 1).contiguous(),
position_ids[:, 1, :].transpose(0, 1).contiguous(),
)
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
...@@ -484,7 +495,7 @@ class SelfAttention(torch.nn.Module): ...@@ -484,7 +495,7 @@ class SelfAttention(torch.nn.Module):
hidden_size_per_partition=self.hidden_size_per_partition, hidden_size_per_partition=self.hidden_size_per_partition,
layer_id=layer_id, layer_id=layer_id,
layer_past=layer_past, layer_past=layer_past,
use_cache=use_cache use_cache=use_cache,
) )
output = self.dense(context_layer) output = self.dense(context_layer)
...@@ -509,8 +520,16 @@ class GEGLU(torch.nn.Module): ...@@ -509,8 +520,16 @@ class GEGLU(torch.nn.Module):
class GLU(torch.nn.Module): class GLU(torch.nn.Module):
def __init__(self, hidden_size, inner_hidden_size=None, def __init__(
layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): self,
hidden_size,
inner_hidden_size=None,
layer_id=None,
bias=True,
activation_func=gelu,
params_dtype=torch.float,
empty_init=True,
):
super(GLU, self).__init__() super(GLU, self).__init__()
if empty_init: if empty_init:
init_method = skip_init init_method = skip_init
...@@ -557,19 +576,19 @@ class GLU(torch.nn.Module): ...@@ -557,19 +576,19 @@ class GLU(torch.nn.Module):
class GLMBlock(torch.nn.Module): class GLMBlock(torch.nn.Module):
def __init__( def __init__(
self, self,
hidden_size, hidden_size,
num_attention_heads, num_attention_heads,
layernorm_epsilon, layernorm_epsilon,
layer_id, layer_id,
inner_hidden_size=None, inner_hidden_size=None,
hidden_size_per_attention_head=None, hidden_size_per_attention_head=None,
layernorm=LayerNorm, layernorm=LayerNorm,
use_bias=True, use_bias=True,
params_dtype=torch.float, params_dtype=torch.float,
num_layers=28, num_layers=28,
position_encoding_2d=True, position_encoding_2d=True,
empty_init=True empty_init=True,
): ):
super(GLMBlock, self).__init__() super(GLMBlock, self).__init__()
# Set output layer initialization if not provided. # Set output layer initialization if not provided.
...@@ -590,7 +609,7 @@ class GLMBlock(torch.nn.Module): ...@@ -590,7 +609,7 @@ class GLMBlock(torch.nn.Module):
bias=use_bias, bias=use_bias,
params_dtype=params_dtype, params_dtype=params_dtype,
position_encoding_2d=self.position_encoding_2d, position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init empty_init=empty_init,
) )
# Layernorm on the input data. # Layernorm on the input data.
...@@ -605,18 +624,18 @@ class GLMBlock(torch.nn.Module): ...@@ -605,18 +624,18 @@ class GLMBlock(torch.nn.Module):
bias=use_bias, bias=use_bias,
layer_id=layer_id, layer_id=layer_id,
params_dtype=params_dtype, params_dtype=params_dtype,
empty_init=empty_init empty_init=empty_init,
) )
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_ids, position_ids,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
layer_id, layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
): ):
""" """
hidden_states: [seq_len, batch, hidden_size] hidden_states: [seq_len, batch, hidden_size]
...@@ -635,7 +654,7 @@ class GLMBlock(torch.nn.Module): ...@@ -635,7 +654,7 @@ class GLMBlock(torch.nn.Module):
layer_id=layer_id, layer_id=layer_id,
layer_past=layer_past, layer_past=layer_past,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions output_attentions=output_attentions,
) )
attention_output = attention_outputs[0] attention_output = attention_outputs[0]
...@@ -702,10 +721,15 @@ class ChatGLMPreTrainedModel(PreTrainedModel): ...@@ -702,10 +721,15 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths): for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i] position_ids[i, context_length:] = mask_positions[i]
block_position_ids = [torch.cat(( block_position_ids = [
torch.zeros(context_length, dtype=torch.long, device=device), torch.cat(
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 (
)) for context_length in context_lengths] torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
)
)
for context_length in context_lengths
]
block_position_ids = torch.stack(block_position_ids, dim=0) block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1) position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else: else:
...@@ -823,9 +847,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -823,9 +847,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.prefix_projection = config.prefix_projection self.prefix_projection = config.prefix_projection
self.word_embeddings = init_method( self.word_embeddings = init_method(
torch.nn.Embedding, torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
dtype=self.params_dtype
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
...@@ -841,12 +863,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -841,12 +863,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_bias=True, use_bias=True,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
position_encoding_2d=self.position_encoding_2d, position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init empty_init=empty_init,
) )
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
[get_layer(layer_id) for layer_id in range(self.num_layers)]
)
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
...@@ -876,7 +896,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -876,7 +896,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.pre_seq_len, self.pre_seq_len,
self.num_layers * 2, self.num_layers * 2,
self.num_attention_heads, self.num_attention_heads,
self.hidden_size // self.num_attention_heads self.hidden_size // self.num_attention_heads,
) )
# seq_len, b, nh, hidden_size # seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values) past_key_values = self.dropout(past_key_values)
...@@ -891,18 +911,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -891,18 +911,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -931,17 +950,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -931,17 +950,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if past_key_values is None: if past_key_values is None:
if self.pre_seq_len is not None: if self.pre_seq_len is not None:
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, past_key_values = self.get_prompt(
dtype=inputs_embeds.dtype) batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
)
else: else:
past_key_values = tuple([None] * len(self.layers)) past_key_values = tuple([None] * len(self.layers))
if attention_mask is None: if attention_mask is None:
attention_mask = self.get_masks( attention_mask = self.get_masks(input_ids, device=input_ids.device)
input_ids,
device=input_ids.device
)
if position_ids is None: if position_ids is None:
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
...@@ -955,15 +971,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -955,15 +971,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_gmasks.append(use_gmask) use_gmasks.append(use_gmask)
position_ids = self.get_position_ids( position_ids = self.get_position_ids(
input_ids, input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
mask_positions=mask_positions,
device=input_ids.device,
use_gmasks=use_gmasks
) )
if self.pre_seq_len is not None and attention_mask is not None: if self.pre_seq_len is not None and attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
attention_mask.device) attention_mask.device
)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool() prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
...@@ -980,7 +994,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -980,7 +994,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
attention_mask = attention_mask.to(hidden_states.device) attention_mask = attention_mask.to(hidden_states.device)
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_past = past_key_values[i] layer_past = past_key_values[i]
...@@ -994,7 +1007,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -994,7 +1007,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
torch.tensor(i), torch.tensor(i),
layer_past, layer_past,
use_cache, use_cache,
output_attentions output_attentions,
) )
else: else:
layer_ret = layer( layer_ret = layer(
...@@ -1004,7 +1017,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): ...@@ -1004,7 +1017,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
layer_id=torch.tensor(i), layer_id=torch.tensor(i),
layer_past=layer_past, layer_past=layer_past,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions output_attentions=output_attentions,
) )
hidden_states = layer_ret[0] hidden_states = layer_ret[0]
...@@ -1049,13 +1062,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1049,13 +1062,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
self.transformer = ChatGLMModel(config, empty_init=empty_init) self.transformer = ChatGLMModel(config, empty_init=empty_init)
self.lm_head = init_method( self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
nn.Linear,
config.hidden_size,
config.vocab_size,
bias=False,
dtype=torch.half
)
self.config = config self.config = config
...@@ -1087,32 +1094,29 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1087,32 +1094,29 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
attention_mask = model_kwargs["attention_mask"] attention_mask = model_kwargs["attention_mask"]
if attention_mask is not None and attention_mask.dtype == torch.bool: if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = torch.cat( attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
)
new_attention_mask = attention_mask[:, :, -1:].clone() new_attention_mask = attention_mask[:, :, -1:].clone()
new_attention_mask[..., -1] = False new_attention_mask[..., -1] = False
model_kwargs["attention_mask"] = torch.cat( model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
[attention_mask, new_attention_mask], dim=2
)
# update position ids # update position ids
if "position_ids" in model_kwargs: if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"] position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone() new_position_id = position_ids[..., -1:].clone()
new_position_id[:, 1, :] += 1 new_position_id[:, 1, :] += 1
model_kwargs["position_ids"] = torch.cat( model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
[position_ids, new_position_id], dim=-1
)
return model_kwargs return model_kwargs
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None, past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None,
**kwargs **kwargs,
) -> dict: ) -> dict:
batch_size, seq_length = input_ids.shape batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
...@@ -1137,11 +1141,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1137,11 +1141,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
if self.position_encoding_2d: if self.position_encoding_2d:
position_ids = torch.tensor( position_ids = torch.tensor(
[[mask_position, seq_length - context_length] for mask_position, context_length in [
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) [mask_position, seq_length - context_length]
for mask_position, context_length in zip(mask_positions, context_lengths)
],
dtype=torch.long,
device=input_ids.device,
).unsqueeze(-1)
else: else:
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, position_ids = torch.tensor(
device=input_ids.device).unsqueeze(-1) [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
).unsqueeze(-1)
if past is None: if past is None:
past = past_key_values past = past_key_values
...@@ -1149,44 +1159,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1149,44 +1159,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
"input_ids": last_token, "input_ids": last_token,
"past_key_values": past, "past_key_values": past,
"position_ids": position_ids, "position_ids": position_ids,
"attention_mask": attention_mask "attention_mask": attention_mask,
} }
else: else:
if attention_mask is not None and attention_mask.dtype != torch.bool: if attention_mask is not None and attention_mask.dtype != torch.bool:
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
attention_mask = None attention_mask = None
if attention_mask is None: if attention_mask is None:
attention_mask = self.get_masks( attention_mask = self.get_masks(input_ids, device=input_ids.device)
input_ids,
device=input_ids.device
)
if position_ids is None: if position_ids is None:
position_ids = self.get_position_ids( position_ids = self.get_position_ids(
input_ids, input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
device=input_ids.device,
mask_positions=mask_positions,
use_gmasks=use_gmasks
) )
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"past_key_values": past, "past_key_values": past,
"position_ids": position_ids, "position_ids": position_ids,
"attention_mask": attention_mask "attention_mask": attention_mask,
} }
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None, past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
): ):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1235,7 +1239,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1235,7 +1239,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@staticmethod @staticmethod
def _reorder_cache( def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
""" """
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
...@@ -1268,15 +1272,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1268,15 +1272,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
return response return response
@torch.no_grad() @torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, def chat(
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 2048,
num_beams=1,
do_sample=True,
top_p=0.7,
temperature=0.95,
logits_processor=None,
**kwargs,
):
if history is None: if history is None:
history = [] history = []
if logits_processor is None: if logits_processor is None:
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor()) logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, gen_kwargs = {
"temperature": temperature, "logits_processor": logits_processor, **kwargs} "max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if not history: if not history:
prompt = query prompt = query
else: else:
...@@ -1287,22 +1309,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1287,22 +1309,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt") inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device) inputs = inputs.to(self.device)
outputs = self.generate(**inputs, **gen_kwargs) outputs = self.generate(**inputs, **gen_kwargs)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
response = self.process_response(response) response = self.process_response(response)
history = history + [(query, response)] history = history + [(query, response)]
return response, history return response, history
@torch.no_grad() @torch.no_grad()
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, def stream_chat(
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 2048,
do_sample=True,
top_p=0.7,
temperature=0.95,
logits_processor=None,
**kwargs,
):
if history is None: if history is None:
history = [] history = []
if logits_processor is None: if logits_processor is None:
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor()) logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, gen_kwargs = {
"temperature": temperature, "logits_processor": logits_processor, **kwargs} "max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if not history: if not history:
prompt = query prompt = query
else: else:
...@@ -1313,7 +1351,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1313,7 +1351,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt") inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device) inputs = inputs.to(self.device)
for outputs in self.stream_generate(**inputs, **gen_kwargs): for outputs in self.stream_generate(**inputs, **gen_kwargs):
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
response = self.process_response(response) response = self.process_response(response)
new_history = history + [(query, response)] new_history = history + [(query, response)]
...@@ -1321,13 +1359,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): ...@@ -1321,13 +1359,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@torch.no_grad() @torch.no_grad()
def stream_generate( def stream_generate(
self, self,
input_ids, input_ids,
generation_config: Optional[GenerationConfig] = None, generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs, **kwargs,
): ):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
......
...@@ -16,9 +16,9 @@ except ImportError: ...@@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def _prepare_logits_processor(top_k: Optional[int] = None, def _prepare_logits_processor(
top_p: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
temperature: Optional[float] = None) -> LogitsProcessorList: ) -> LogitsProcessorList:
processor_list = LogitsProcessorList() processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0: if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature)) processor_list.append(TemperatureLogitsWarper(temperature))
...@@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: ...@@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0 return unfinished_sequences.max() == 0
def _sample(model: Actor, def _sample(
input_ids: torch.Tensor, model: Actor,
max_length: int, input_ids: torch.Tensor,
early_stopping: bool = False, max_length: int,
eos_token_id: Optional[int] = None, early_stopping: bool = False,
pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
top_k: Optional[int] = None, pad_token_id: Optional[int] = None,
top_p: Optional[float] = None, top_k: Optional[int] = None,
temperature: Optional[float] = None, top_p: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, temperature: Optional[float] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor: update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
if input_ids.size(1) >= max_length: if input_ids.size(1) >= max_length:
return input_ids return input_ids
...@@ -56,11 +58,12 @@ def _sample(model: Actor, ...@@ -56,11 +58,12 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length): for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \ model_inputs = (
if prepare_inputs_fn is not None else {'input_ids': input_ids} prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
)
outputs = model(**model_inputs) outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :] next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution # pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits) next_token_logits = logits_processor(input_ids, next_token_logits)
# sample # sample
...@@ -90,20 +93,22 @@ def _sample(model: Actor, ...@@ -90,20 +93,22 @@ def _sample(model: Actor,
@torch.no_grad() @torch.no_grad()
def generate(model: Actor, def generate(
input_ids: torch.Tensor, model: Actor,
max_length: int, input_ids: torch.Tensor,
num_beams: int = 1, max_length: int,
do_sample: bool = True, num_beams: int = 1,
early_stopping: bool = False, do_sample: bool = True,
eos_token_id: Optional[int] = None, early_stopping: bool = False,
pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
top_k: Optional[int] = None, pad_token_id: Optional[int] = None,
top_p: Optional[float] = None, top_k: Optional[int] = None,
temperature: Optional[float] = None, top_p: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, temperature: Optional[float] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs) -> torch.Tensor: update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens. """Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args: Args:
...@@ -121,26 +126,28 @@ def generate(model: Actor, ...@@ -121,26 +126,28 @@ def generate(model: Actor,
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
""" """
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = ((num_beams == 1) and do_sample is True) is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = ((num_beams > 1) and do_sample is False) is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode: if is_greedy_gen_mode:
# run greedy search # run greedy search
raise NotImplementedError raise NotImplementedError
elif is_sample_gen_mode: elif is_sample_gen_mode:
# run sample # run sample
return _sample(model, return _sample(
input_ids, model,
max_length, input_ids,
early_stopping=early_stopping, max_length,
eos_token_id=eos_token_id, early_stopping=early_stopping,
pad_token_id=pad_token_id, eos_token_id=eos_token_id,
top_k=top_k, pad_token_id=pad_token_id,
top_p=top_p, top_k=top_k,
temperature=temperature, top_p=top_p,
prepare_inputs_fn=prepare_inputs_fn, temperature=temperature,
update_model_kwargs_fn=update_model_kwargs_fn, prepare_inputs_fn=prepare_inputs_fn,
**model_kwargs) update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
)
elif is_beam_gen_mode: elif is_beam_gen_mode:
raise NotImplementedError raise NotImplementedError
else: else:
......
...@@ -2,4 +2,4 @@ from .gpt_actor import GPTActor ...@@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] __all__ = ["GPTActor", "GPTCritic", "GPTRM"]
...@@ -18,13 +18,15 @@ class GPTActor(Actor): ...@@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer. lora_train_bias (str): Bias training strategy for the LoRa layer.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[GPT2Config] = None, pretrained: Optional[str] = None,
checkpoint: bool = False, config: Optional[GPT2Config] = None,
lora_rank: int = 0, checkpoint: bool = False,
lora_train_bias: str = 'none', lora_rank: int = 0,
**kwargs) -> None: lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained) model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -18,12 +18,14 @@ class GPTCritic(Critic): ...@@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[GPT2Config] = None, pretrained: Optional[str] = None,
lora_rank: int = 0, config: Optional[GPT2Config] = None,
lora_train_bias: str = 'none', lora_rank: int = 0,
**kwargs) -> None: lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained) model = GPT2Model.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -18,11 +18,13 @@ class GPTRM(RewardModel): ...@@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[GPT2Config] = None, pretrained: Optional[str] = None,
lora_rank: int = 0, config: Optional[GPT2Config] = None,
lora_train_bias: str = 'none') -> None: lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained) model = GPT2Model.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -2,4 +2,4 @@ from .llama_actor import LlamaActor ...@@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM'] __all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
from typing import Optional from typing import Optional
import torch from transformers import LlamaConfig, LlamaForCausalLM
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from ..base import Actor from ..base import Actor
...@@ -18,13 +17,14 @@ class LlamaActor(Actor): ...@@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[LlamaConfig] = None, pretrained: Optional[str] = None,
checkpoint: bool = False, config: Optional[LlamaConfig] = None,
lora_rank: int = 0, checkpoint: bool = False,
lora_train_bias: str = 'none') -> None: lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained) model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -17,13 +17,14 @@ class LlamaCritic(Critic): ...@@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[LlamaConfig] = None, pretrained: Optional[str] = None,
lora_rank: int = 0, config: Optional[LlamaConfig] = None,
lora_train_bias: str = 'none', lora_rank: int = 0,
**kwargs) -> None: lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained) model = LlamaModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
from typing import Optional from typing import Optional
import torch.nn as nn import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel from ..base import RewardModel
...@@ -17,12 +17,13 @@ class LlamaRM(RewardModel): ...@@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[LlamaConfig] = None, pretrained: Optional[str] = None,
lora_rank: int = 0, config: Optional[LlamaConfig] = None,
lora_train_bias: str = 'none') -> None: lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained) model = LlamaModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -8,8 +8,7 @@ import torch.nn.functional as F ...@@ -8,8 +8,7 @@ import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module): class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
"""
def __init__( def __init__(
self, self,
...@@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module):
bias: Optional[nn.Parameter], bias: Optional[nn.Parameter],
r: int = 0, r: int = 0,
lora_alpha: int = 1, lora_alpha: int = 1,
lora_dropout: float = 0., lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True, merge_weights: bool = True,
): ):
nn.Module.__init__(self) nn.Module.__init__(self)
lora.LoRALayer.__init__(self, lora.LoRALayer.__init__(
r=r, self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
lora_alpha=lora_alpha, )
lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.weight = weight self.weight = weight
self.bias = bias self.bias = bias
...@@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.weight.data = self.weight.data.T self.weight.data = self.weight.data.T
def reset_parameters(self): def reset_parameters(self):
if hasattr(self, 'lora_A'): if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero. # Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B) nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True): def train(self, mode: bool = True):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
...@@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.merged = False self.merged = False
def eval(self): def eval(self):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
...@@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it # Merge the weights and mark it
if self.r > 0: if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A') delattr(self, "lora_A")
delattr(self, 'lora_B') delattr(self, "lora_B")
self.merged = True self.merged = True
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
...@@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear return lora_linear
...@@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: ...@@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank) _convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module. """Convert a torch.nn.Module to a LoRA module.
Args: Args:
...@@ -140,7 +136,7 @@ class LoRAModule(nn.Module): ...@@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'. Defaults to 'none'.
""" """
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__() super().__init__()
self.lora_rank = lora_rank self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias self.lora_train_bias = lora_train_bias
......
...@@ -31,11 +31,13 @@ class PolicyLoss(nn.Module): ...@@ -31,11 +31,13 @@ class PolicyLoss(nn.Module):
super().__init__() super().__init__()
self.clip_eps = clip_eps self.clip_eps = clip_eps
def forward(self, def forward(
log_probs: torch.Tensor, self,
old_log_probs: torch.Tensor, log_probs: torch.Tensor,
advantages: torch.Tensor, old_log_probs: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp() ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
...@@ -55,14 +57,16 @@ class ValueLoss(nn.Module): ...@@ -55,14 +57,16 @@ class ValueLoss(nn.Module):
super().__init__() super().__init__()
self.clip_eps = clip_eps self.clip_eps = clip_eps
def forward(self, def forward(
values: torch.Tensor, self,
old_values: torch.Tensor, values: torch.Tensor,
reward: torch.Tensor, old_values: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2 surr1 = (values_clipped - reward) ** 2
surr2 = (values - reward)**2 surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2) loss = torch.max(surr1, surr2)
loss = loss.mean() loss = loss.mean()
return 0.5 * loss return 0.5 * loss
......
...@@ -2,4 +2,4 @@ from .opt_actor import OPTActor ...@@ -2,4 +2,4 @@ from .opt_actor import OPTActor
from .opt_critic import OPTCritic from .opt_critic import OPTCritic
from .opt_rm import OPTRM from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] __all__ = ["OPTActor", "OPTCritic", "OPTRM"]
...@@ -18,12 +18,14 @@ class OPTActor(Actor): ...@@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[OPTConfig] = None, pretrained: Optional[str] = None,
checkpoint: bool = False, config: Optional[OPTConfig] = None,
lora_rank: int = 0, checkpoint: bool = False,
lora_train_bias: str = 'none') -> None: lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained) model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -18,12 +18,14 @@ class OPTCritic(Critic): ...@@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[OPTConfig] = None, pretrained: Optional[str] = None,
lora_rank: int = 0, config: Optional[OPTConfig] = None,
lora_train_bias: str = 'none', lora_rank: int = 0,
**kwargs) -> None: lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTModel.from_pretrained(pretrained) model = OPTModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -17,11 +17,13 @@ class OPTRM(RewardModel): ...@@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
pretrained: Optional[str] = None, self,
config: Optional[OPTConfig] = None, pretrained: Optional[str] = None,
lora_rank: int = 0, config: Optional[OPTConfig] = None,
lora_train_bias: str = 'none') -> None: lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTModel.from_pretrained(pretrained) model = OPTModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -4,9 +4,9 @@ import torch ...@@ -4,9 +4,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
def _compute_approx_kl(log_probs: torch.Tensor, def _compute_approx_kl(
log_probs_base: torch.Tensor, log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute the approximate KL divergence between two distributions. Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html Schulman blog: http://joschu.net/blog/kl-approx.html
...@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor, ...@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl return approx_kl
def compute_reward(r: Union[torch.Tensor, float], def compute_reward(
kl_coef: float, r: Union[torch.Tensor, float],
log_probs: torch.Tensor, kl_coef: float,
log_probs_base: torch.Tensor, log_probs: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if kl_coef <= 0.0: if kl_coef <= 0.0:
return r return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
...@@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num ...@@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
Returns: Returns:
torch.Tensor: Action log probs. torch.Tensor: Action log probs.
""" """
logits = output['logits'] logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:] return log_probs[:, -num_actions:]
......
...@@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant ...@@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
from .utils import low_resource_init from .utils import low_resource_init
__all__ = [ __all__ = [
'llama_load_quant', "llama_load_quant",
'low_resource_init', "low_resource_init",
] ]
from .loader import load_quant from .loader import load_quant
__all__ = [ __all__ = [
'load_quant', "load_quant",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment