Unverified commit b69a62d5, authored by Thomas Wang and committed by GitHub

[BLOOM] Clean modeling code (#18344)



* Cleanup some code

* Improve signatures

* Try to reduce the number of reshape/copies

* I don't think we actually need the layer_num scaling trick

* No need for duplication

* Try to fix beam_search

* Fix beam search

* Removing layer num normalization seems to be breaking

* Not sure self.layer_number normalization actually matters

* Try and be backward compatible

* Try to fix beam_search

* Revert attempt to be backward compatible

* Improve documentation on past_key_values format

* Optimize the device allocation in case of hidden_states in multiple devices

* No need to manually cast the values to a specific device

* Rename with long version of variables

* Improve type hinting

* Add comment that explains that some methods return views

* Actually I think the attention casting only makes sense when we use torch.float16

* We don't actually need layer_number to be passed anymore

* Fix FX test

* Bypass torch.baddbmm

* Apply suggestions from code review

* Add comment about support for torchScript v1.11

* fix ONNX support for bloom (#18456)
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
parent 02b176c4
```diff
@@ -214,14 +214,19 @@ class BloomOnnxConfig(OnnxConfigWithPast):
                 batch, seqlen = common_inputs["input_ids"].shape
                 # Not using the same length for past_key_values
                 past_key_values_length = seqlen + 2
-                past_shape = (
-                    batch,
-                    past_key_values_length,
-                    self.num_attention_heads,
-                    self._config.hidden_size // self.num_attention_heads,
+                head_dim = self._config.hidden_size // self.num_attention_heads
+                past_key_shape = (
+                    batch * self.num_attention_heads,
+                    head_dim,
+                    past_key_values_length,
+                )
+                past_value_shape = (
+                    batch * self.num_attention_heads,
+                    past_key_values_length,
+                    head_dim,
                 )
                 ordered_inputs["past_key_values"] = [
-                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
                 ]

         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
```
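For reference, here is a minimal, self-contained sketch of the dummy cache the new ONNX shapes produce. The dimensions (`batch=2`, `seqlen=5`, 4 heads, hidden size 64, 3 layers) are made up for illustration only:

```python
import torch

# Assumed toy dimensions, chosen only for illustration
batch, seqlen = 2, 5
num_attention_heads, hidden_size, num_layers = 4, 64, 3

past_key_values_length = seqlen + 2
head_dim = hidden_size // num_attention_heads

# After this commit the dummy cache mirrors what BloomAttention stores:
# key is [batch * num_heads, head_dim, kv_length], value is [batch * num_heads, kv_length, head_dim]
past_key_shape = (batch * num_attention_heads, head_dim, past_key_values_length)
past_value_shape = (batch * num_attention_heads, past_key_values_length, head_dim)

past_key_values = [
    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(num_layers)
]
assert past_key_values[0][0].shape == (8, 16, 7)
assert past_key_values[0][1].shape == (8, 7, 16)
```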
```diff
@@ -16,12 +16,13 @@
 import math
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from torch.nn import functional as F

 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import (
@@ -52,102 +53,100 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+) -> torch.BoolTensor:
     """
-    Make causal mask used for bi-directional self-attention.
+    Make causal mask used for self-attention.
     """
     batch_size, target_length = input_ids_shape
-    mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
-    mask_cond = torch.arange(mask.size(-1))
-    intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
-    mask.masked_fill_(intermediate_mask, 0)
-    mask = mask.to(dtype)
+    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]

     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask[:, :past_key_values_length] = False
+
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask


-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
     """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
     """
-    batch_size, source_length = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else source_length
+    batch_size, src_length = mask.shape
+    tgt_length = tgt_length if tgt_length is not None else src_length

-    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype)
-    inverted_mask = 1.0 - expanded_mask
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
+    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


-def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
     `softmax(l+a) = softmax(l)`. Based on
     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

     Args:
-    Returns tensor shaped (batch_size * n_head, 1, max_seq_len)
+    Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
         attention_mask (`torch.Tensor`):
             Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
-        n_head (`int`, *required*):
+        num_heads (`int`, *required*):
             number of heads
         dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
             dtype of the output tensor
-        device (`torch.device`, *optional*, default=`torch.device('cpu')`):
-            device of the output alibi tensor
     """
-    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
-    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
-    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(
+        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
     slopes = torch.pow(base, powers)

-    if closest_power_of_2 != n_head:
+    if closest_power_of_2 != num_heads:
         extra_base = torch.tensor(
-            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
         )
-        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
-        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
         slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

     # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
     # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
-    # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
     # => the query_length dimension will then be broadcasted correctly
     # This is more or less identical to T5's relative position bias:
     # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
-    # batch_size = 1, n_head = n_head, query_length
-
-    arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None]
-    alibi = slopes.unsqueeze(-1) * arange_tensor
-    alibi = alibi * attention_mask[:, None]
-    return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype)
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


-def dropout_add(x, residual, prob, training):
+def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
     """
     Dropout add function

     Args:
         x (`torch.tensor`, *required*):
             input tensor
-        residual (`torch.tensor`, *rquired*):
+        residual (`torch.tensor`, *required*):
             esidual tensor
         prob (`float`, *required*):
             dropout probability
         training (`bool`, *required*):
             training mode
     """
-    out = nn.functional.dropout(x, p=prob, training=training)
+    out = F.dropout(x, p=prob, training=training)
     out = residual + out
     return out


-def bloom_gelu_forward(x):
+def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
     """
     Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
     make the model jitable.
```
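As a side note, here is a short standalone sketch of the slope/position logic that `build_alibi_tensor` implements, assuming a toy 4-head setup; it mirrors the formula above but is not the library function itself:

```python
import math
import torch

def alibi_sketch(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Toy re-derivation of the ALiBi bias: one geometric slope per head times each token's position."""
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.tensor([base ** i for i in range(1, closest_power_of_2 + 1)])
    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        slopes = torch.cat([slopes, torch.tensor([extra_base ** i for i in range(1, 2 * num_remaining, 2)])])
    # position index of each non-padding token; padding positions contribute 0
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length)

mask = torch.ones(1, 6)
print(alibi_sketch(mask, num_heads=4).shape)  # torch.Size([4, 1, 6])
```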
```diff
@@ -159,7 +158,7 @@ def bloom_gelu_forward(x):
     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


-def bloom_gelu_back(g, x):
+def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
     0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -179,12 +178,12 @@ def bloom_gelu_back(g, x):
 class GeLUFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input):
+    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
         ctx.save_for_backward(input)
         return bloom_gelu_forward(input)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         input = ctx.saved_tensors
         tmp = bloom_gelu_back(grad_output, input)
         return tmp
@@ -197,13 +196,12 @@ class BloomGelu(nn.Module):
     copied from Megatron-DeepSpeed code and adapted for our needs

     See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
     """

     def __init__(self):
         super().__init__()

-
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training:
             return GeLUFunction.apply(x)
         else:
@@ -211,7 +209,7 @@ class BloomGelu(nn.Module):

 class BloomAttention(nn.Module):
-    def __init__(self, config, layer_number=None):
+    def __init__(self, config: BloomConfig):
         super().__init__()

         self.pretraining_tp = config.pretraining_tp
```
```diff
@@ -230,106 +228,131 @@ class BloomAttention(nn.Module):
         )

         # Layer-wise attention scaling
-        self.layer_number = max(1, layer_number)
-        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.beta = 1.0

         self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
         self.dense = nn.Linear(self.hidden_size, self.hidden_size)
         self.attention_dropout = nn.Dropout(config.attention_dropout)

-    def _split_heads(self, fused_qkv):
+    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Split the last dimension into (num_heads, head_dim)
+        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+        storage as `fused_qkv`
+
+        Args:
+            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+        Returns:
+            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+            value: [batch_size, seq_length, num_heads, head_dim]
         """
-        new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
-        # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
-        # fused_qkv = fused_qkv.transpose(1, 0)
-        fused_qkv = fused_qkv.reshape(new_tensor_shape)
-        # fused_qkv = fused_qkv.permute(0, 2, 1, 3)
-        return torch.split(fused_qkv, self.head_dim, -1)
+        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]

-    def _merge_heads(self, x):
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Merge heads together over the last dimenstion
+
+        Args:
+            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
+
+        Returns:
+            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+        """
         # What we want to achieve is:
-        # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
+        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+        batch_size_and_num_heads, seq_length, _ = x.shape
+        batch_size = batch_size_and_num_heads // self.num_heads

         # First view to decompose the batch size
-        # batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
-        x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim)
+        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

-        # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim
+        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
         x = x.permute(0, 2, 1, 3)

-        # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
-        return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim)
+        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

     def forward(
         self,
-        hidden_states,
-        residual,
-        layer_past=None,
-        attention_mask=None,
-        alibi=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
     ):
-        alibi = alibi.to(hidden_states.device)  # to make the model possible to run under accelerate
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

         # 3 x [batch_size, seq_length, num_heads, head_dim]
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

+        batch_size, q_length, _, _ = query_layer.shape
+
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+
         if layer_past is not None:
             past_key, past_value = layer_past
-            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
-            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
-            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
+            # concatenate along seq_length dimension:
+            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
+            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=2)
+            value_layer = torch.cat((past_value, value_layer), dim=1)
+
+        _, _, kv_length = key_layer.shape

         if use_cache is True:
             present = (key_layer, value_layer)
         else:
             present = None

-        beta = 1.0 / self.layer_number
-
-        # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
-        matmul_result = (1.0 / self.norm_factor) * torch.bmm(
-            query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]),
-            key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]),
-        ) + beta * alibi
+        # [batch_size * num_heads, q_length, kv_length]
+        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+        matmul_result = alibi.baddbmm(
+            batch1=query_layer,
+            batch2=key_layer,
+            beta=self.beta,
+            alpha=self.inv_norm_factor,
+        )

-        # change view to [batch_size, num_heads, q_length, k_length]
-        attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2))
+        # change view to [batch_size, num_heads, q_length, kv_length]
+        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

-        # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length]
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
         input_dtype = attention_scores.dtype
-        attn_weights = (attention_scores * self.layer_number) + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-        attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
-        attention_probs = attention_probs * (~attention_mask.to(torch.bool))
-        # [batch_size, num_heads, q_length, k_length]
+        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+        if input_dtype == torch.float16:
+            attention_scores = attention_scores.to(torch.float)
+        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+
+        # [batch_size, num_heads, q_length, kv_length]
         attention_probs = self.attention_dropout(attention_probs)

         if head_mask is not None:
             attention_probs = attention_probs * head_mask

-        # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(matmul_result.shape)
+        # change view [batch_size x num_heads, q_length, kv_length]
+        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

         # matmul: [batch_size * num_heads, q_length, head_dim]
-        context_layer = torch.bmm(
-            attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
-        )
+        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

         # change view [batch_size, num_heads, q_length, head_dim]
         context_layer = self._merge_heads(context_layer)

         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
         if self.pretraining_tp > 1 and self.slow_but_exact:
-            slices = context_layer.shape[-1] / self.pretraining_tp
+            slices = self.hidden_size / self.pretraining_tp
             output_tensor = torch.zeros_like(context_layer)
             for i in range(self.pretraining_tp):
-                output_tensor = output_tensor + nn.functional.linear(
+                output_tensor = output_tensor + F.linear(
                     context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                     self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                 )
```
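To see why the `baddbmm` rewrite is a pure refactor, here is a toy check (made-up sizes, not model code) that the fused call reproduces the old scale-then-add formulation:

```python
import torch

# Made-up sizes: batch_size * num_heads = 6, q_length = 4, kv_length = 4, head_dim = 8
query = torch.randn(6, 4, 8)
key_t = torch.randn(6, 8, 4)   # key already stored transposed: [b*h, head_dim, kv_length]
alibi = torch.randn(6, 1, 4)   # broadcasts over the query dimension
inv_norm_factor = 1.0 / (8 ** 0.5)

# Old formulation: explicit bmm, explicit scaling, explicit bias addition
old = inv_norm_factor * torch.bmm(query, key_t) + 1.0 * alibi

# New formulation: one fused call, beta scales the bias (alibi), alpha scales the product
new = alibi.baddbmm(batch1=query, batch2=key_t, beta=1.0, alpha=inv_norm_factor)

print(torch.allclose(old, new, atol=1e-6))  # True
```

The fused form avoids materialising the unscaled score matrix, and `Tensor.baddbmm` (unlike `torch.baddbmm`) keeps the module scriptable with TorchScript v1.11, which is the reason given in the diff comment.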
```diff
@@ -346,7 +369,7 @@ class BloomAttention(nn.Module):

 class BloomMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size
@@ -357,14 +380,14 @@ class BloomMLP(nn.Module):
         self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
         self.hidden_dropout = config.hidden_dropout

-    def forward(self, hidden_states, residual):
+    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
         hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))

         if self.pretraining_tp > 1 and self.slow_but_exact:
             intermediate_output = torch.zeros_like(residual)
             slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
             for i in range(self.pretraining_tp):
-                intermediate_output = intermediate_output + nn.functional.linear(
+                intermediate_output = intermediate_output + F.linear(
                     hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
                     self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
                 )
@@ -377,13 +400,13 @@ class BloomMLP(nn.Module):

 class BloomBlock(nn.Module):
-    def __init__(self, config, layer_number=None):
+    def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size

         self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.n_head = config.n_head
-        self.self_attention = BloomAttention(config, layer_number=layer_number)
+        self.num_heads = config.n_head
+        self.self_attention = BloomAttention(config)
         self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

         self.mlp = BloomMLP(config)
@@ -393,13 +416,13 @@ class BloomBlock(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
-        alibi=None,
+        hidden_states: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
     ):
         # hidden_states: [batch_size, seq_length, hidden_size]
@@ -462,9 +485,9 @@ class BloomPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -478,7 +501,7 @@ class BloomPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
         if isinstance(module, BloomModel):
             module.gradient_checkpointing = value
@@ -501,9 +524,8 @@ BLOOM_START_DOCSTRING = r"""
 BLOOM_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
-            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
-            sequence tokens in the vocabulary.
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
             `input_ids`.
@@ -516,6 +538,10 @@ BLOOM_INPUTS_DOCSTRING = r"""
             Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
             `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
             their past given to this model should not be passed as `input_ids` as they have already been computed.
+
+            Each element of `past_key_values` is a tuple (past_key, past_value):
+            - past_key: [batch_size * num_heads, head_dim, kv_length]
+            - past_value: [batch_size * num_heads, kv_length, head_dim]
         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
```
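The documented cache layout can be checked directly with a tiny randomly initialised model. This sketch assumes a transformers version that includes this commit; the config values are arbitrary and not a released checkpoint:

```python
import torch
from transformers import BloomConfig, BloomModel

# Tiny random model, just to inspect the documented past_key_values layout
config = BloomConfig(vocab_size=100, hidden_size=64, n_layer=2, n_head=4)
model = BloomModel(config).eval()

input_ids = torch.randint(0, 100, (2, 5))  # batch_size=2, seq_length=5
with torch.no_grad():
    outputs = model(input_ids, use_cache=True)

past_key, past_value = outputs.past_key_values[0]
print(past_key.shape)    # torch.Size([8, 16, 5])  -> [batch_size * num_heads, head_dim, kv_length]
print(past_value.shape)  # torch.Size([8, 5, 16])  -> [batch_size * num_heads, kv_length, head_dim]
```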
```diff
@@ -555,19 +581,18 @@ BLOOM_INPUTS_DOCSTRING = r"""
     BLOOM_START_DOCSTRING,
 )
 class BloomModel(BloomPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)

         self.embed_dim = config.hidden_size
-        self.n_head = config.n_head
+        self.num_heads = config.n_head

         # Embedding + LN Embedding
         self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         # Transformer blocks
-        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])

         # Final Layer Norm
         self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -580,25 +605,29 @@ class BloomModel(BloomPreTrainedModel):
     def get_input_embeddings(self):
         return self.word_embeddings

-    def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    def _prepare_attn_mask(
+        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    ) -> torch.BoolTensor:
         # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         combined_attention_mask = None
-        if input_shape[-1] > 1:
+        device = attention_mask.device
+        _, src_length = input_shape
+
+        if src_length > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(attention_mask.device)
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+                input_shape, device=device, past_key_values_length=past_key_values_length
             )

+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
+        combined_attention_mask = (
+            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+        )
+
         return combined_attention_mask

-    def set_input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings

     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
```
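A small standalone illustration of how the causal mask and the padding mask now combine as booleans (`True` = masked) with `|`; toy sizes, and the helper here is my own sketch rather than the module code:

```python
import torch

def make_causal_mask_sketch(target_length: int, past_key_values_length: int) -> torch.Tensor:
    # True means "do not attend"; cached (past) positions are always visible
    mask = torch.empty(target_length, target_length + past_key_values_length, dtype=torch.bool)
    seq_ids = torch.arange(target_length)
    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = False
    return mask[None, None, :, :]

attention_mask = torch.tensor([[1, 1, 1, 0]])                    # last token is padding
padding_mask = ~attention_mask[:, None, None, :].to(torch.bool)  # True where padded
causal_mask = make_causal_mask_sketch(target_length=4, past_key_values_length=0)

# Combined with `|`: a position is masked if it is in the future OR it is padding
combined = padding_mask | causal_mask
print(combined[0, 0])
```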
```diff
@@ -610,17 +639,17 @@ class BloomModel(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
-    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         if deprecated_arguments.pop("position_ids", False) is not False:
             # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
             warnings.warn(
@@ -641,10 +670,9 @@ class BloomModel(BloomPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -653,8 +681,8 @@ class BloomModel(BloomPreTrainedModel):
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_head x N x N
-        # head_mask has shape n_layer x batch x n_head x N x N
+        # attention_probs has shape batch_size x num_heads x N x N
+        # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if inputs_embeds is None:
@@ -662,27 +690,28 @@ class BloomModel(BloomPreTrainedModel):
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)

-        output_shape = input_shape + (hidden_states.size(-1),)
-
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None

         # Compute alibi tensor: check build_alibi_tensor documentation
-        current_sequence_length = hidden_states.shape[1]
+        seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]
-            current_sequence_length += past_key_values_length
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device)
+            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)

-        alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device)
+        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
+        causal_mask = self._prepare_attn_mask(
+            attention_mask,
+            input_shape=(batch_size, seq_length),
+            past_key_values_length=past_key_values_length,
+        )

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -700,14 +729,14 @@ class BloomModel(BloomPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions, alibi)
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                     return custom_forward

                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
-                    None,
+                    alibi,
                     causal_mask,
                     head_mask[i],
                 )
@@ -735,8 +764,6 @@ class BloomModel(BloomPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        hidden_states = hidden_states.view(output_shape)
-
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
```
```diff
@@ -758,7 +785,7 @@ class BloomModel(BloomPreTrainedModel):
 class BloomForCausalLM(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.transformer = BloomModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -769,16 +796,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: torch.Tensor):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
+        # only last token for input_ids if past is not None
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)

-        attention_mask = kwargs.get("attention_mask", None)
-
         return {
             "input_ids": input_ids,
             "past_key_values": past,
```
```diff
@@ -795,16 +826,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
@@ -845,9 +876,12 @@ class BloomForCausalLM(BloomPreTrainedModel):
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = loss_fct(
+                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+            )

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
```
```diff
@@ -862,14 +896,36 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )

     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
         beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
         """
+        batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
+        batch_size = len(beam_idx)
+        num_heads = batch_size_times_num_heads // batch_size
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
+        # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
         return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            (
+                layer_past[0]
+                .view(batch_size, num_heads, head_dim, seq_length)
+                .index_select(0, device_to_beam_idx[layer_past[0].device])
+                .view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1]
+                .view(batch_size, num_heads, seq_length, head_dim)
+                .index_select(0, device_to_beam_idx[layer_past[0].device])
+                .view(batch_size_times_num_heads, seq_length, head_dim),
+            )
             for layer_past in past
         )
```
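A quick sanity check, with hypothetical sizes, that re-ordering the cache through a `[batch_size, num_heads, ...]` view selects the same rows as gathering directly over the flattened `batch_size * num_heads` dimension:

```python
import torch

batch_size, num_heads, head_dim, seq_length = 3, 2, 4, 5
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
beam_idx = torch.tensor([2, 0, 0])  # e.g. beam 0 now continues from old beam 2

# View-based reorder, as in `_reorder_cache` above
reordered = (
    key.view(batch_size, num_heads, head_dim, seq_length)
    .index_select(0, beam_idx)
    .view(batch_size * num_heads, head_dim, seq_length)
)

# Equivalent "manual" gather over the flattened batch*heads dimension
flat_idx = (beam_idx[:, None] * num_heads + torch.arange(num_heads)).reshape(-1)
print(torch.equal(reordered, key.index_select(0, flat_idx)))  # True
```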
```diff
@@ -892,7 +948,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
 class BloomForSequenceClassification(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.transformer = BloomModel(config)
@@ -910,16 +966,16 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
@@ -966,7 +1022,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
             else:
                 sequence_lengths = -1
                 logger.warning(
@@ -994,7 +1050,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                 loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+                loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
@@ -1021,7 +1077,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
 class BloomForTokenClassification(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1047,16 +1103,16 @@ class BloomForTokenClassification(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
@@ -1095,8 +1151,11 @@ class BloomForTokenClassification(BloomPreTrainedModel):
         loss = None
         if labels is not None:
+            batch_size, seq_length = labels.shape
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(
+                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+            )

         if not return_dict:
             output = (logits,) + transformer_outputs[2:]
```