Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
```"""
model_type = "chatglm"
def __init__(
self,
vocab_size=130528,
hidden_size=4096,
num_layers=28,
num_attention_heads=32,
layernorm_epsilon=1e-5,
use_cache=True,
bos_token_id=130004,
eos_token_id=130005,
mask_token_id=130000,
gmask_token_id=130001,
pad_token_id=3,
max_sequence_length=2048,
inner_hidden_size=16384,
position_encoding_2d=True,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs
self,
vocab_size=130528,
hidden_size=4096,
num_layers=28,
num_attention_heads=32,
layernorm_epsilon=1e-5,
use_cache=True,
bos_token_id=130004,
eos_token_id=130005,
mask_token_id=130000,
gmask_token_id=130001,
pad_token_id=3,
max_sequence_length=2048,
inner_hidden_size=16384,
position_encoding_2d=True,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs,
):
self.num_layers = num_layers
self.vocab_size = vocab_size
......@@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig):
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
\ No newline at end of file
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
......@@ -4,41 +4,40 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/mo
""" PyTorch ChatGLM model. """
import math
import copy
import math
import os
import warnings
import re
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
if sys.platform != 'darwin':
if sys.platform != "darwin":
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
......@@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
......@@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
......@@ -153,7 +152,7 @@ class PrefixEncoder(torch.nn.Module):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
......@@ -170,8 +169,7 @@ class PrefixEncoder(torch.nn.Module):
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def gelu(x):
......@@ -181,21 +179,22 @@ def gelu(x):
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half()
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
self.register_buffer('inv_freq', inv_freq)
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
pass
def forward(self, x, seq_dim=1, seq_len=None):
......@@ -204,7 +203,7 @@ class RotaryEmbedding(torch.nn.Module):
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
......@@ -230,30 +229,31 @@ class RotaryEmbedding(torch.nn.Module):
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
position_id, sin.squeeze(1)
).unsqueeze(2)
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
hidden_size_per_partition,
layer_id,
layer_past=None,
scaling_attention_score=True,
use_cache=False,
self,
query_layer,
key_layer,
value_layer,
attention_mask,
hidden_size_per_partition,
layer_id,
layer_past=None,
scaling_attention_score=True,
use_cache=False,
):
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
......@@ -285,7 +285,9 @@ def attention_fn(
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
matmul_result = torch.zeros(
1, 1, 1,
1,
1,
1,
dtype=query_layer.dtype,
device=query_layer.device,
)
......@@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):
class SelfAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads,
layer_id, hidden_size_per_attention_head=None, bias=True,
params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
def __init__(
self,
hidden_size,
num_attention_heads,
layer_id,
hidden_size_per_attention_head=None,
bias=True,
params_dtype=torch.float,
position_encoding_2d=True,
empty_init=True,
):
if empty_init:
init_method = skip_init
else:
......@@ -410,8 +420,7 @@ class SelfAttention(torch.nn.Module):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def split_tensor_along_last_dim(self, tensor, num_partitions,
contiguous_split_chunks=False):
def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
......@@ -431,14 +440,14 @@ class SelfAttention(torch.nn.Module):
return tensor_list
def forward(
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
......@@ -462,8 +471,10 @@ class SelfAttention(torch.nn.Module):
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
position_ids[:, 1, :].transpose(0, 1).contiguous()
position_ids, block_position_ids = (
position_ids[:, 0, :].transpose(0, 1).contiguous(),
position_ids[:, 1, :].transpose(0, 1).contiguous(),
)
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
......@@ -484,7 +495,7 @@ class SelfAttention(torch.nn.Module):
hidden_size_per_partition=self.hidden_size_per_partition,
layer_id=layer_id,
layer_past=layer_past,
use_cache=use_cache
use_cache=use_cache,
)
output = self.dense(context_layer)
......@@ -509,8 +520,16 @@ class GEGLU(torch.nn.Module):
class GLU(torch.nn.Module):
def __init__(self, hidden_size, inner_hidden_size=None,
layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
def __init__(
self,
hidden_size,
inner_hidden_size=None,
layer_id=None,
bias=True,
activation_func=gelu,
params_dtype=torch.float,
empty_init=True,
):
super(GLU, self).__init__()
if empty_init:
init_method = skip_init
......@@ -557,19 +576,19 @@ class GLU(torch.nn.Module):
class GLMBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
layernorm_epsilon,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
layernorm=LayerNorm,
use_bias=True,
params_dtype=torch.float,
num_layers=28,
position_encoding_2d=True,
empty_init=True
self,
hidden_size,
num_attention_heads,
layernorm_epsilon,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
layernorm=LayerNorm,
use_bias=True,
params_dtype=torch.float,
num_layers=28,
position_encoding_2d=True,
empty_init=True,
):
super(GLMBlock, self).__init__()
# Set output layer initialization if not provided.
......@@ -590,7 +609,7 @@ class GLMBlock(torch.nn.Module):
bias=use_bias,
params_dtype=params_dtype,
position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init
empty_init=empty_init,
)
# Layernorm on the input data.
......@@ -605,18 +624,18 @@ class GLMBlock(torch.nn.Module):
bias=use_bias,
layer_id=layer_id,
params_dtype=params_dtype,
empty_init=empty_init
empty_init=empty_init,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
......@@ -635,7 +654,7 @@ class GLMBlock(torch.nn.Module):
layer_id=layer_id,
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
......@@ -702,10 +721,15 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i]
block_position_ids = [torch.cat((
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
)) for context_length in context_lengths]
block_position_ids = [
torch.cat(
(
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
)
)
for context_length in context_lengths
]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else:
......@@ -823,9 +847,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.prefix_projection = config.prefix_projection
self.word_embeddings = init_method(
torch.nn.Embedding,
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
dtype=self.params_dtype
torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
)
self.gradient_checkpointing = False
......@@ -841,12 +863,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_bias=True,
params_dtype=self.params_dtype,
position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init
empty_init=empty_init,
)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(self.num_layers)]
)
self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
......@@ -876,7 +896,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.pre_seq_len,
self.num_layers * 2,
self.num_attention_heads,
self.hidden_size // self.num_attention_heads
self.hidden_size // self.num_attention_heads,
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
......@@ -891,18 +911,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -931,17 +950,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if past_key_values is None:
if self.pre_seq_len is not None:
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
dtype=inputs_embeds.dtype)
past_key_values = self.get_prompt(
batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
)
else:
past_key_values = tuple([None] * len(self.layers))
if attention_mask is None:
attention_mask = self.get_masks(
input_ids,
device=input_ids.device
)
attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
......@@ -955,15 +971,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_gmasks.append(use_gmask)
position_ids = self.get_position_ids(
input_ids,
mask_positions=mask_positions,
device=input_ids.device,
use_gmasks=use_gmasks
input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
)
if self.pre_seq_len is not None and attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
attention_mask.device)
attention_mask.device
)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
......@@ -980,7 +994,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
attention_mask = attention_mask.to(hidden_states.device)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_past = past_key_values[i]
......@@ -994,7 +1007,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
torch.tensor(i),
layer_past,
use_cache,
output_attentions
output_attentions,
)
else:
layer_ret = layer(
......@@ -1004,7 +1017,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
layer_id=torch.tensor(i),
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions
output_attentions=output_attentions,
)
hidden_states = layer_ret[0]
......@@ -1049,13 +1062,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
self.transformer = ChatGLMModel(config, empty_init=empty_init)
self.lm_head = init_method(
nn.Linear,
config.hidden_size,
config.vocab_size,
bias=False,
dtype=torch.half
)
self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
self.config = config
......@@ -1087,32 +1094,29 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
attention_mask = model_kwargs["attention_mask"]
if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
)
new_attention_mask = attention_mask[:, :, -1:].clone()
new_attention_mask[..., -1] = False
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, new_attention_mask], dim=2
)
model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
# update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id[:, 1, :] += 1
model_kwargs["position_ids"] = torch.cat(
[position_ids, new_position_id], dim=-1
)
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
return model_kwargs
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs
self,
input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
......@@ -1137,11 +1141,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
if self.position_encoding_2d:
position_ids = torch.tensor(
[[mask_position, seq_length - context_length] for mask_position, context_length in
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
[
[mask_position, seq_length - context_length]
for mask_position, context_length in zip(mask_positions, context_lengths)
],
dtype=torch.long,
device=input_ids.device,
).unsqueeze(-1)
else:
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
device=input_ids.device).unsqueeze(-1)
position_ids = torch.tensor(
[mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
).unsqueeze(-1)
if past is None:
past = past_key_values
......@@ -1149,44 +1159,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
"input_ids": last_token,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
else:
if attention_mask is not None and attention_mask.dtype != torch.bool:
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
attention_mask = None
if attention_mask is None:
attention_mask = self.get_masks(
input_ids,
device=input_ids.device
)
attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions,
use_gmasks=use_gmasks
input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
)
return {
"input_ids": input_ids,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -1235,7 +1239,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@staticmethod
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
......@@ -1268,15 +1272,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
return response
@torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
def chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 2048,
num_beams=1,
do_sample=True,
top_p=0.7,
temperature=0.95,
logits_processor=None,
**kwargs,
):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if not history:
prompt = query
else:
......@@ -1287,22 +1309,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
outputs = self.generate(**inputs, **gen_kwargs)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@torch.no_grad()
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
def stream_chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 2048,
do_sample=True,
top_p=0.7,
temperature=0.95,
logits_processor=None,
**kwargs,
):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
gen_kwargs = {
"max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if not history:
prompt = query
else:
......@@ -1313,7 +1351,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
for outputs in self.stream_generate(**inputs, **gen_kwargs):
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
new_history = history + [(query, response)]
......@@ -1321,13 +1359,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@torch.no_grad()
def stream_generate(
self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
......
......@@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def _prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
def _prepare_logits_processor(
top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
......@@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
def _sample(model: Actor,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def _sample(
model: Actor,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
......@@ -56,11 +58,12 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
if prepare_inputs_fn is not None else {'input_ids': input_ids}
model_inputs = (
prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
)
outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :]
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
......@@ -90,20 +93,22 @@ def _sample(model: Actor,
@torch.no_grad()
def generate(model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def generate(
model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
......@@ -121,26 +126,28 @@ def generate(model: Actor,
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
return _sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
return _sample(
model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
)
elif is_beam_gen_mode:
raise NotImplementedError
else:
......
......@@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
__all__ = ["GPTActor", "GPTCritic", "GPTRM"]
......@@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
......
......@@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
......
......@@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
......
......@@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaForCausalLM
from ..base import Actor
......@@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
......
......@@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
......
from typing import Optional
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel
......@@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
......
......@@ -8,8 +8,7 @@ import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
"""
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
......@@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module):
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights)
lora.LoRALayer.__init__(
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
)
self.weight = weight
self.bias = bias
......@@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.weight.data = self.weight.data.T
def reset_parameters(self):
if hasattr(self, 'lora_A'):
if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
......@@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
......@@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A')
delattr(self, 'lora_B')
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
......@@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
......@@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
......@@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
......
......@@ -31,11 +31,13 @@ class PolicyLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
......@@ -55,14 +57,16 @@ class ValueLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps
def forward(self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2
surr2 = (values - reward)**2
surr1 = (values_clipped - reward) ** 2
surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
......
......@@ -2,4 +2,4 @@ from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
__all__ = ["OPTActor", "OPTCritic", "OPTRM"]
......@@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
......
......@@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
......
......@@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
......
......@@ -4,9 +4,9 @@ import torch
import torch.nn.functional as F
def _compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def _compute_approx_kl(
log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
......@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl
def compute_reward(r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def compute_reward(
r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
......@@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
Returns:
torch.Tensor: Action log probs.
"""
logits = output['logits']
logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
......
......@@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
from .utils import low_resource_init
__all__ = [
'llama_load_quant',
'low_resource_init',
"llama_load_quant",
"low_resource_init",
]
from .loader import load_quant
__all__ = [
'load_quant',
"load_quant",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment