"doc/vscode:/vscode.git/clone" did not exist on "d7811a4f32f9d6d729cafa8fbcb43c07e62406da"
Commit 61e92904 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
from enum import Enum, auto


class DTypes(Enum):
    FP8E4M3 = auto()
    FP8E5M2 = auto()
    KFLOAT16 = auto()
import torch
import transformer_engine as te  # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.meta import FP8Meta


@torch.no_grad()
def fp8_matmul_kernel(
    mat_a: FP8Tensor,
    transpose_a: bool,
    mat_b: FP8Tensor,
    transpose_b: bool,
    use_split_accumulator: bool,
) -> torch.Tensor:
    assert (
        mat_a.device != "cpu" and mat_b.device != "cpu"
    ), "The tensors must be on a CUDA device in order to use the FP8 kernel!"

    device = mat_a.device

    _empty_tensor = torch.Tensor()
    output = torch.empty(mat_a.shape[0], mat_b.shape[1], device=device, dtype=torch.float32)
    workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
    accumulate = False

    out_dtype = getattr(tex.DType, "kFloat32")
    # NOTE: TE currently doesn't support adding a bias in FP8
    # along with the matmul; it only takes an empty bias
    bias = torch.tensor([], dtype=torch.float32)
    TE_CONFIG_TRANSPOSE_BIAS = False

    mat_a_fp8_meta: FP8Meta = mat_a.fp8_meta
    mat_b_fp8_meta: FP8Meta = mat_b.fp8_meta

    # NOTE: these are the only configs that TE accepts,
    # so we have to transpose the A and B matrices to match them
    TE_CONFIG_TRANSPOSE_A = True
    TE_CONFIG_TRANSPOSE_B = False
    SCALE = AMAX = _empty_tensor

    mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a
    mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b

    tex.te_gemm(
        mat_a,
        mat_a_fp8_meta.inverse_scale,
        mat_a_fp8_meta.te_dtype,
        TE_CONFIG_TRANSPOSE_A,
        mat_b,
        mat_b_fp8_meta.inverse_scale,
        mat_b_fp8_meta.te_dtype,
        TE_CONFIG_TRANSPOSE_B,
        output,
        SCALE,
        out_dtype,
        AMAX,
        bias,
        out_dtype,
        _empty_tensor,
        TE_CONFIG_TRANSPOSE_BIAS,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,
    )

    return output
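

# Hedged usage sketch (not part of the original diff): it drives fp8_matmul_kernel
# the same way _FP8Matmul.forward does (weight as mat_a with transpose_a=True,
# activation as mat_b with transpose_b=False). It assumes a CUDA device whose GPU
# supports FP8 (e.g. H100-class) and uses square matrices to sidestep questions
# about TE's fixed transpose layout.
def _example_fp8_matmul():  # pragma: no cover
    from nanotron.fp8.dtypes import DTypes

    w = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=DTypes.FP8E4M3)
    x = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=DTypes.FP8E4M3)
    out = fp8_matmul_kernel(
        mat_a=w, transpose_a=True, mat_b=x, transpose_b=False, use_split_accumulator=False
    )
    return out  # accumulated in FP32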
from typing import Optional, Tuple, TypedDict, Union

import torch
import torch.nn.functional as F
import transformer_engine as te  # noqa
from torch import nn

from nanotron.fp8.constants import INITIAL_AMAX, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor


class FP8LinearMeta(TypedDict):
    """FP8 metadata for FP8Linear."""

    input_grad: FP8Meta
    weight_grad: FP8Meta
    output_grad: FP8Meta


class FP8Linear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device: Optional[torch.device] = None):
        super().__init__(in_features, out_features, bias, device)
        # TODO(xrsrke): add device, and 2 fp8 dtypes
        if self.weight.device != torch.device("cpu"):
            self.weight = FP8Parameter(self.weight, dtype=DTypes.FP8E4M3)

            # NOTE: quantization metadata for input gradients, weight gradients, and output gradients
            # TODO(xrsrke): don't hardcode these initial values
            fp8e4m3_scale = update_scaling_factor(
                amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
                scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR),
                dtype=DTypes.FP8E4M3,
            )
            fp8e5m2_scale = update_scaling_factor(
                amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
                scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32),
                dtype=DTypes.FP8E5M2,
            )
            self.fp8_meta: FP8LinearMeta = {
                # kfloat8_e4m3
                "input_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
                "weight_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
                # kfloat8_e5m2
                "output_grad": FP8Meta(amax=1, dtype=DTypes.FP8E5M2, scale=fp8e5m2_scale),
            }

    def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
        # NOTE: only run the FP8 kernel if both the input and the weight are on a CUDA device
        if input.device == torch.device("cpu") or self.weight.device == torch.device("cpu"):
            return F.linear(input, self.weight, self.bias)

        # NOTE: just a phony tensor to make PyTorch trigger the backward pass,
        # because the weight's and bias's requires_grad are set to False,
        # so that we can compute the gradients with the FP8 kernels ourselves
        phony = torch.empty(0, device=input.device, requires_grad=True)
        output, _ = _FP8Matmul.apply(input, self.weight, self.fp8_meta, phony)

        # TODO(xrsrke): add support for adding bias in fp8
        # TODO(xrsrke): support returning an fp8 tensor as output,
        # since we will quantize it back to FP8 anyway in the next linear
        output = output if self.bias is None else output + self.bias
        return output
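

# Hedged usage sketch (not part of the original diff): FP8Linear behaves like nn.Linear,
# but quantizes its weight to FP8E4M3 when built on a CUDA device and falls back to
# F.linear on CPU. This assumes a CUDA device with FP8 support (e.g. H100-class) and
# square shapes, since the kernel's fixed TE transpose configuration constrains the layout.
def _example_fp8_linear():  # pragma: no cover
    layer = FP8Linear(32, 32, bias=True, device=torch.device("cuda"))
    x = torch.randn(32, 32, device="cuda")
    y = layer(x)  # the matmul runs through the FP8 kernel, the bias is added in higher precision
    return y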
class _FP8Matmul(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(
        ctx, input: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8LinearMeta, phony: torch.Tensor
    ) -> torch.Tensor:
        if type(input) == torch.Tensor:
            input = FP8Tensor(input, dtype=DTypes.FP8E4M3)

        ctx.save_for_backward(input, weight)
        ctx.fp8_meta = fp8_meta

        # NOTE: pass an FP8Tensor instead of an FP8Parameter
        output = fp8_matmul_kernel(
            mat_a=weight.data, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False
        )

        return output, phony

    @staticmethod
    @torch.no_grad()
    def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[torch.Tensor, None, None, None]:
        """
        ∂L/∂X = ∂L/∂Y @ Wᵀ
        ∂L/∂W = Xᵀ @ ∂L/∂Y
        Source: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
        """
        # TODO(xrsrke): investigate how grad_output.contiguous() affects the outputs
        input, weight = ctx.saved_tensors

        if type(grad_output) == torch.Tensor:
            grad_output = torch.ones_like(grad_output)
            grad_output = grad_output.contiguous()
            grad_output = FP8Tensor(grad_output, dtype=DTypes.FP8E5M2)

        grad_input = fp8_matmul_kernel(
            mat_a=grad_output, transpose_a=True, mat_b=weight, transpose_b=True, use_split_accumulator=True
        )
        grad_weight = fp8_matmul_kernel(
            mat_a=input, transpose_a=False, mat_b=grad_output, transpose_b=False, use_split_accumulator=True
        )
        weight.grad = grad_weight

        return grad_input, None, None, None
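

# Hedged reference check (not part of the original diff) of the formulas in the backward
# docstring above, using plain FP32 tensors and autograd: for Y = X @ W,
# dL/dX == dL/dY @ W.T and dL/dW == X.T @ dL/dY. It runs on CPU and needs no FP8 support.
def _check_linear_backprop_formulas():  # pragma: no cover
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 3, requires_grad=True)
    y = x @ w
    grad_y = torch.randn_like(y)
    y.backward(grad_y)
    assert torch.allclose(x.grad, grad_y @ w.T)
    assert torch.allclose(w.grad, x.T @ grad_y)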
from dataclasses import dataclass
from typing import Union

import torch
import transformer_engine as te  # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.constants import DTYPE_TO_FP8_MAX
from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype


@dataclass
class FP8Meta:
    """Metadata for FP8Tensor."""

    amax: Union[int, float]
    scale: torch.Tensor

    # TODO(xrsrke): change to Literal[torch.int8, torch.uint8]
    dtype: torch.dtype

    @property
    def te_dtype(self) -> tex.DType:
        return convert_torch_dtype_to_te_dtype(self.dtype)

    def __post_init__(self):
        # NOTE: transformer engine only accepts torch tensors
        self.amax = torch.tensor(self.amax, device="cuda") if not isinstance(self.amax, torch.Tensor) else self.amax

    @property
    def fp8_max(self) -> float:
        """Return the maximum normal value for the current dtype."""
        return DTYPE_TO_FP8_MAX[self.dtype]

    @property
    def inverse_scale(self) -> torch.Tensor:
        return 1 / self.scale

    def __repr__(self) -> str:
        return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype})"
import torch
from torch import nn

from nanotron.fp8.constants import FP8_DTYPES
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import FP8Tensor


class FP8Parameter(nn.Parameter):
    """
    A custom FP8 parameter class that allows gradients
    to flow into FP8 tensors (which are integer tensors).
    """

    def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter:
        assert isinstance(data, torch.Tensor), "data must be a tensor"
        assert data.dtype not in FP8_DTYPES, "Currently only support turning a non-fp8 tensor into an fp8 parameter"
        assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors"
        # TODO(xrsrke): if the tensor is on cpu, then bypass quantization

        with torch.no_grad():
            # TODO(xrsrke): support taking an FP8 Tensor as data
            # currently we can only quantize the tensor to FP8 after the parameter is created,
            # because otherwise it raises "Only Tensors of floating point and complex dtype can require gradients"
            self = torch.Tensor._make_subclass(cls, data, requires_grad)
            self._data = FP8Tensor(data, dtype=dtype)
        return self

    @property
    def data(self) -> FP8Tensor:
        return self._data

    @data.setter
    def data(self, data: FP8Tensor):
        self._data = data

    @property
    def fp8_meta(self) -> FP8Meta:
        return self.data.fp8_meta

    def __repr__(self) -> str:
        return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad})"
import torch
import transformer_engine as te  # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.constants import DTYPE_TO_FP8_MAX, FP8_DTYPES, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes


class FP8Tensor(torch.Tensor):
    """FP8 Tensor."""

    def __new__(cls, tensor: torch.Tensor, dtype: DTypes) -> torch.Tensor:
        assert isinstance(tensor, torch.Tensor), "tensor must be a tensor"
        assert tensor.dtype not in FP8_DTYPES, "The tensor is already quantized to FP8"

        # TODO(xrsrke): there is a circular import issue
        # between tensor.py and meta.py, fix this
        from nanotron.fp8.meta import FP8Meta

        # TODO(xrsrke): if the tensor is on cpu, then bypass the quantization
        # because the current kernels only support gpu tensors
        assert tensor.device != torch.device("cpu"), "FP8Tensor only supports CUDA device"
        assert isinstance(dtype, DTypes)

        amax = tensor.abs().max().clone()
        scale = update_scaling_factor(amax, torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), dtype)
        fp8_meta = FP8Meta(amax, scale, dtype)
        fp8_tensor = convert_tensor_to_fp8(tensor, fp8_meta)

        # TODO(xrsrke): move updating the inverse scaling to FP8Meta's initialization
        obj = torch.Tensor._make_subclass(cls, fp8_tensor)
        obj.fp8_meta = fp8_meta
        return obj

    def __repr__(self) -> str:
        return f"FP8Tensor({self}, fp8_meta={self.fp8_meta})"


def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType:
    # NOTE: transformer engine maintains its own dtype mapping,
    # so we need to manually map torch dtypes to TE dtypes
    TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = {
        torch.int32: "kInt32",
        torch.float32: "kFloat32",
        torch.float16: "kFloat16",
        torch.bfloat16: "kBFloat16",
        # torch.fp8e5m2: "kFloat8E5M2",
        # torch.fp8e4m3: "kFloat8E4M3",
        # torch.int8: "kFloat8E5M2",
        # torch.uint8: "kFloat8E4M3",
        DTypes.FP8E4M3: "kFloat8E4M3",
        DTypes.FP8E5M2: "kFloat8E5M2",
        DTypes.KFLOAT16: "kFloat16",
    }
    return getattr(tex.DType, TORCH_DTYPE_TE_DTYPE_NAME_MAPPING[dtype])


# TODO(xrsrke): add a type hint for meta after fixing
# the circular import between tensor.py and meta.py
def convert_tensor_to_fp8(tensor: torch.Tensor, meta) -> FP8Tensor:
    te_dtype = convert_torch_dtype_to_te_dtype(meta.dtype)
    # TODO(xrsrke): after casting to fp8, update the scaling factor
    # TODO(xrsrke): it's weird that TE only takes an inverse_scale equal to 1
    inverse_scale = torch.tensor(1.0, device=tensor.device, dtype=torch.float32)
    return tex.cast_to_fp8(tensor, meta.scale, meta.amax, inverse_scale, te_dtype)


def convert_tensor_from_fp8(tensor: torch.Tensor, meta, dtype: torch.dtype) -> torch.Tensor:
    assert isinstance(tensor, torch.Tensor)
    assert isinstance(dtype, torch.dtype)
    tensor_dtype = convert_torch_dtype_to_te_dtype(meta.dtype)
    output_dtype = convert_torch_dtype_to_te_dtype(dtype)
    return tex.cast_from_fp8(tensor, meta.inverse_scale, tensor_dtype, output_dtype)


def update_scaling_factor(
    amax: torch.Tensor, scaling_factor: torch.Tensor, dtype: DTypes, margin: float = 0
) -> torch.Tensor:
    """
    Update the scaling factor to quantize a tensor to FP8.
    Credits: https://github.com/Azure/MS-AMP/blob/d562f0f0bcfc9b712fa0726b73428753ff1300ab/msamp/common/tensor/meta.py#L39
    """
    assert amax.dtype == torch.float32
    # TODO(xrsrke): can we use lower precision for scaling_factor?
    assert scaling_factor.dtype == torch.float32

    # NOTE: since fp8_max is a fixed number based on the two FP8 data types,
    # we prefer not to take fp8_max as an input argument
    fp8_max = torch.tensor(DTYPE_TO_FP8_MAX[dtype], dtype=torch.float32)

    # NOTE: torch.jit only takes a concrete value rather than DTYPE_TO_FP8_MAX[dtype],
    # so we create an inner function to work around that
    @torch.jit.script
    def _inner(amax: torch.Tensor, fp8_max: torch.Tensor, scaling_factor: torch.Tensor, margin: float):
        # NOTE: calculate the number of bits to shift the exponent
        ratio = fp8_max / amax
        exp = torch.floor(torch.log2(ratio)) - margin
        new_scaling_factor = torch.round(torch.pow(2, torch.abs(exp)))
        new_scaling_factor = torch.where(amax > 0.0, new_scaling_factor, scaling_factor)
        new_scaling_factor = torch.where(torch.isfinite(amax), new_scaling_factor, scaling_factor)
        new_scaling_factor = torch.where(exp < 0, 1 / new_scaling_factor, new_scaling_factor)
        return new_scaling_factor

    return _inner(amax, fp8_max, scaling_factor, margin)
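

# Hedged worked example (not part of the original diff) of the scaling-factor math above,
# assuming DTYPE_TO_FP8_MAX maps FP8E4M3 to 448.0 (the E4M3 maximum normal value).
# With amax = 0.5: ratio = 448 / 0.5 = 896, floor(log2(896)) = 9, so the scale becomes
# 2**9 = 512; the tensor is multiplied by this scale before casting so that it fills the
# representable FP8 range, and inverse_scale undoes it on dequantization.
def _example_update_scaling_factor():  # pragma: no cover
    amax = torch.tensor(0.5, dtype=torch.float32)
    scale = update_scaling_factor(amax, torch.tensor(1.0, dtype=torch.float32), DTypes.FP8E4M3)
    # Expected (under the 448.0 assumption): tensor(512.)
    return scale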
import torch
import transformer_engine as te  # noqa

from nanotron.fp8.constants import FP8_GPU_NAMES


def is_fp8_available() -> bool:
    """Check if FP8 is available on the current device."""
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
        return any(gpu_name in device_name for gpu_name in FP8_GPU_NAMES)
    else:
        return False
from .sampler import BasicSampler, GreedySampler, Sampler, SamplerType, TopKSampler, TopPSampler
__all__ = ["BasicSampler", "GreedySampler", "Sampler", "SamplerType", "TopKSampler", "TopPSampler"]
import collections
import contextlib

from torch import nn


class Store(collections.defaultdict):
    """
    We use the store to keep some state locally on the GPU so that we don't have to communicate it.
    This is useful at inference, e.g. when we don't want to recompute the kv_cache
    or communicate it through the pipeline.
    """

    def __init__(self):
        super().__init__(dict)

    def flush(self):
        # TODO @thomasw21: There's probably a simpler way than doing this.
        for key in list(self.keys()):
            del self[key]


class AttachableStore:
    def _attach_store(self, store: Store):
        assert not hasattr(self, "_store"), "You can't assign a store when there's already one attached"
        self._store = store

    def _detach_store(self):
        delattr(self, "_store")

    def get_local_store(self):
        if hasattr(self, "_store"):
            if isinstance(self, nn.Module):
                assert self.training is False, "Store is used only in evaluation mode"
            return self._store[id(self)]
        else:
            return None


@contextlib.contextmanager
def attach_store(model: nn.Module, store: Store):
    list_module_containing_store = []
    for module in model.modules():
        if not isinstance(module, AttachableStore):
            continue
        module._attach_store(store)
        list_module_containing_store.append(module)
    try:
        yield
    finally:
        for module in list_module_containing_store:
            module._detach_store()
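

# Hedged usage sketch (not part of the original diff): a tiny module that mixes in
# AttachableStore so that, inside `attach_store(...)`, it can stash per-module state
# (e.g. a kv_cache at inference time) instead of recomputing or communicating it.
# `TinyBlock` is a hypothetical module used only for illustration.
def _example_attach_store():  # pragma: no cover
    import torch

    class TinyBlock(nn.Module, AttachableStore):
        def forward(self, x):
            store = self.get_local_store()
            if store is not None:
                store["last_input_shape"] = tuple(x.shape)
            return x

    model = nn.Sequential(TinyBlock()).eval()  # stores are only used in evaluation mode
    store = Store()
    with attach_store(model=model, store=store):
        model(torch.zeros(2, 3))
    return store  # holds the state recorded by TinyBlock during the forward pass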
from dataclasses import dataclass
from enum import Enum, auto
from typing import Sequence

import torch

from nanotron import distributed as dist


def all_gather_batches(in_tensor: torch.Tensor, in_split: Sequence[int], group: dist.ProcessGroup) -> torch.Tensor:
    # All-gather along the first dimension, allowing unequal splits
    out_tensor = torch.empty((sum(in_split),) + in_tensor.shape[1:], dtype=in_tensor.dtype, device=in_tensor.device)
    out_split_list = list(torch.split(out_tensor, in_split, dim=0))
    dist.all_gather(out_split_list, in_tensor, group=group)
    return out_tensor
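

# Hedged worked example (not part of the original diff) of the batch-split bookkeeping
# used by the samplers below: with batch_size = 10 and a tensor-parallel group of size 4,
# the first (10 % 4) = 2 ranks get (10 // 4) + 1 = 3 rows and the rest get 2,
# so in_split = (3, 3, 2, 2) and sum(in_split) == batch_size.
def _example_batch_split(batch_size: int = 10, group_size: int = 4):  # pragma: no cover
    min_shard_batch_size = batch_size // group_size
    nb_shard_containing_extra_one = batch_size % group_size
    in_split = tuple(
        min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
        for rank in range(group_size)
    )
    assert sum(in_split) == batch_size
    return in_split  # (3, 3, 2, 2) for the defaults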
class SamplerType(Enum):
    TOP_P = auto()
    TOP_K = auto()
    GREEDY = auto()
    BASIC = auto()


class Sampler:
    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


@dataclass
class TopPSampler(Sampler):
    pg: dist.ProcessGroup
    p: float = 0.9
    temperature: float = 1.0
    filter_value: float = 0.0
    min_tokens_to_keep: int = 1

    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        batch_size, vocab_per_shard = sharded_logits.shape

        # Split the logits into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )
        # out_split should be all equal to be able to concat at the last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # Prepare tensors for the all-to-all operation
        # Gather logits from all vocab shards but shard on batch, tp_rank first
        sharded_logits_out = torch.empty(
            (total_out_size, vocab_per_shard),
            dtype=sharded_logits.dtype,
            device=sharded_logits.device,
        )  # [pg_size * sharded_batch_size, vocab_per_shard]
        local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
        local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))
        dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)
        logits = torch.cat(local_sharded_logits_out, dim=-1)  # [sharded_batch_size, vocab_size]

        probs = torch.softmax(logits.to(dtype=torch.float) / self.temperature, dim=-1)  # [sharded_batch_size, vocab_size]

        # Sort the probs and their corresponding indices in ascending order
        sorted_probs, sorted_indices = torch.sort(probs, descending=False, dim=-1)
        # Calculate the cumulative sum of the sorted probs
        # (bfloat16 is not accurate enough for the cumulative sum)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1, dtype=torch.float)  # [sharded_batch_size, vocab_size]
        # Find the smallest set of indices for which the cumulative probability mass exceeds p
        sorted_indices_to_remove = cumulative_probs <= (1 - self.p)
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (the highest-probability tokens sit at the end of the ascending sort)
            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # Construct the probability mask for the original indices
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        filter_probs = probs.masked_fill(indices_to_remove, self.filter_value)

        sampled_indices = torch.multinomial(filter_probs, num_samples=1)

        # All-gather the new decoder input ids along the batch dimension
        gathered_new_decoder_input_ids = all_gather_batches(sampled_indices, in_split, group=self.pg)
        return gathered_new_decoder_input_ids


@dataclass
class GreedySampler(Sampler):
    pg: dist.ProcessGroup

    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        batch_size, vocab_per_shard = sharded_logits.shape

        # Find the local max logit and its index
        # Note that max is deterministic, and always takes the first one.
        max_values, max_indices = sharded_logits.max(dim=-1, keepdim=True)  # [batch_size, 1]
        # Add offset to the max indices
        # TODO: We're assuming that TensorColumnLinear shards in a specific manner, i.e. rank 0 gets the first shard.
        # It might require us to expose something from TensorColumnLinear.
        max_indices = max_indices + (dist.get_rank(self.pg) * vocab_per_shard)

        # Split max_values/max_indices into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )
        # out_split should be all equal to be able to concat at the last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # Prepare tensors for the all-to-all operation
        # Gather max logits and their indices from all shards, tp_rank first
        max_values_out_mat = torch.empty(
            (total_out_size, 1),
            dtype=max_values.dtype,
            device=max_values.device,
        )
        max_indices_out_mat = torch.empty(
            (total_out_size, 1),
            dtype=max_indices.dtype,
            device=max_indices.device,
        )
        local_max_values_in = list(torch.split(max_values, in_split, dim=0))
        local_max_indices_in = list(torch.split(max_indices, in_split, dim=0))
        local_max_values_out = list(torch.split(max_values_out_mat, out_split, dim=0))
        local_max_indices_out = list(torch.split(max_indices_out_mat, out_split, dim=0))
        dist.all_to_all(local_max_values_out, local_max_values_in, group=self.pg)
        dist.all_to_all(local_max_indices_out, local_max_indices_in, group=self.pg)

        # Concat assumes that the primary dimension is the same across all shards
        sharded_max_values = torch.cat(local_max_values_out, dim=-1)  # [sharded_batch_size, num_shards]
        sharded_max_indices = torch.cat(local_max_indices_out, dim=-1)  # [sharded_batch_size, num_shards]

        # Find the global max logit across all shards
        # Note that max is deterministic, and always takes the first one.
        # [sharded_batch_size, 1]
        _global_max_values, global_max_indices = sharded_max_values.max(dim=-1, keepdim=True)
        # Select the corresponding token index from the offsetted gathered indices
        sharded_selected_tokens = sharded_max_indices.gather(1, global_max_indices)

        # All-gather the new decoder input ids along the batch dimension
        gathered_new_decoder_input_ids = all_gather_batches(sharded_selected_tokens, in_split, group=self.pg)
        return gathered_new_decoder_input_ids


@dataclass
class TopKSampler(Sampler):
    pg: dist.ProcessGroup
    k: int = 50
    temperature: float = 1.0

    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        batch_size, vocab_per_shard = sharded_logits.shape

        # Find the local top-k logits and their indices
        local_top_k_values, local_top_k_indices = torch.topk(sharded_logits, self.k, dim=-1)
        # Add offset to the indices
        local_top_k_indices = local_top_k_indices + (dist.get_rank(self.pg) * vocab_per_shard)

        # Split local_top_k_values into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )
        # out_split should be all equal to be able to concat at the last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # The last shards could be smaller than min_shard_batch_size + 1
        local_top_k_values_in = list(torch.split(local_top_k_values, in_split, dim=0))
        local_top_k_indices_in = list(torch.split(local_top_k_indices, in_split, dim=0))

        # Prepare tensors for the all-to-all operation
        # Gather top-k logits and their indices from all shards, tp_rank first
        top_k_values_out_mat = torch.empty(
            (total_out_size,) + local_top_k_values.shape[1:],
            dtype=local_top_k_values.dtype,
            device=local_top_k_values.device,
        )
        top_k_indices_out_mat = torch.empty(
            (total_out_size,) + local_top_k_indices.shape[1:],
            dtype=local_top_k_indices.dtype,
            device=local_top_k_indices.device,
        )
        local_top_k_values_out = list(torch.split(top_k_values_out_mat, out_split, dim=0))
        local_top_k_indices_out = list(torch.split(top_k_indices_out_mat, out_split, dim=0))
        dist.all_to_all(local_top_k_values_out, local_top_k_values_in, group=self.pg)
        dist.all_to_all(local_top_k_indices_out, local_top_k_indices_in, group=self.pg)

        # Concat assumes that the primary dimension is the same across all shards
        sharded_local_top_k_values = torch.cat(local_top_k_values_out, dim=-1)  # [sharded_batch_size, k * num_shards]
        sharded_local_top_k_indices = torch.cat(
            local_top_k_indices_out, dim=-1
        )  # [sharded_batch_size, k * num_shards]

        # Select the global top-k from the gathered top-k; now the top-k is across the full vocab and batch_size is sharded
        sharded_top_k_values, sharded_top_k_indices = torch.topk(
            sharded_local_top_k_values, self.k, dim=-1
        )  # [sharded_batch_size, k]
        # Select the corresponding indices from the gathered indices
        sharded_top_k_indices = sharded_local_top_k_indices.gather(
            -1, sharded_top_k_indices
        )  # [sharded_batch_size, k]

        # Apply temperature and compute softmax probabilities
        probs = torch.softmax(sharded_top_k_values.to(dtype=torch.float) / self.temperature, dim=-1)
        # Sample from the probabilities
        sampled_indices = torch.multinomial(probs, num_samples=1)  # [sharded_batch_size, 1]
        # Select the corresponding token index from the global top-k indices
        new_decoder_input_ids = sharded_top_k_indices.gather(-1, sampled_indices)  # [sharded_batch_size, 1]

        # All-gather the new decoder input ids along the batch dimension
        gathered_new_decoder_input_ids = all_gather_batches(new_decoder_input_ids, in_split, group=self.pg)
        return gathered_new_decoder_input_ids


@dataclass
class BasicSampler(Sampler):
    """Basic sampler that samples from the full vocab according to the logits."""

    pg: dist.ProcessGroup

    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        # We cross batch and vocab shards so that each rank samples from the full vocab on a part of the batch
        # (right now logits are sharded on both vocab and batch, so we need an all-to-all)
        batch_size, vocab_per_shard = sharded_logits.shape

        # Split the logits into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )
        # out_split should be all equal to be able to concat at the last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # Prepare tensors for the all-to-all operation
        # Gather logits from all vocab shards but shard on batch, tp_rank first
        sharded_logits_out = torch.empty(
            (total_out_size, vocab_per_shard),
            dtype=sharded_logits.dtype,
            device=sharded_logits.device,
        )  # [pg_size * sharded_batch_size, vocab_per_shard]
        local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
        local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))
        dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)
        logits = torch.cat(local_sharded_logits_out, dim=-1)  # [sharded_batch_size, vocab_size]

        probs = torch.softmax(logits.to(dtype=torch.float), dim=-1)  # [sharded_batch_size, vocab_size]
        # Sample from the probabilities
        sampled_indices = torch.multinomial(probs, num_samples=1)

        # All-gather the new decoder input ids along the batch dimension
        gathered_new_decoder_input_ids = all_gather_batches(sampled_indices, in_split, group=self.pg)
        return gathered_new_decoder_input_ids
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities. """
import logging
import os
import sys
from dataclasses import dataclass
from functools import lru_cache
from logging import (
CRITICAL,
DEBUG,
ERROR,
FATAL,
INFO,
NOTSET,
WARNING,
Formatter,
Logger,
)
from typing import TYPE_CHECKING, List, Optional, Union
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
if TYPE_CHECKING:
from nanotron.config import LoggingArgs
from nanotron.parallel import ParallelContext
log_levels = {
"debug": DEBUG,
"info": INFO,
"warning": WARNING,
"error": ERROR,
"critical": CRITICAL,
"fatal": FATAL,
"notset": NOTSET,
}
class NewLineStreamHandler(logging.StreamHandler):
"""
We want to apply formatter before each new line
https://stackoverflow.com/a/38458877
"""
def emit(self, record):
lines = record.msg.split("\n")
for line in lines:
record.msg = line
super().emit(record)
DEFAULT_HANDLER = NewLineStreamHandler()
DEFAULT_LOG_LEVEL = logging.WARNING
LIBRARY_NAME = __name__.split(".")[0]
def _get_default_logging_level():
    """
    If the NANOTRON_LOGGING_LEVEL env var is set to one of the valid choices, return that as the new default level.
    If it is not, fall back to ``DEFAULT_LOG_LEVEL``.
    """
    env_level_str = os.getenv("NANOTRON_LOGGING_LEVEL", None)
    if env_level_str:
        if env_level_str in log_levels:
            return log_levels[env_level_str]
        else:
            logging.getLogger().warning(
                f"Unknown option NANOTRON_LOGGING_LEVEL={env_level_str}, "
                f"has to be one of: { ', '.join(log_levels.keys()) }"
            )
    return DEFAULT_LOG_LEVEL


def get_library_root_logger() -> Logger:
    return get_logger(LIBRARY_NAME)


def _configure_library_root_logger() -> None:
    library_root_logger = get_library_root_logger()
    library_root_logger.addHandler(DEFAULT_HANDLER)
    library_root_logger.setLevel(_get_default_logging_level())


def _reset_library_root_logger() -> None:
    library_root_logger = get_library_root_logger()
    library_root_logger.setLevel(logging.NOTSET)


def get_logger(name: Optional[str] = None, log_level: Optional[str] = None) -> Logger:
    """
    Return a logger with the specified name.
    """
    logger_already_exists = isinstance(logging.root.manager.loggerDict.get(name, None), Logger)
    logger = logging.getLogger(name)
    if logger_already_exists or name is None:
        # if name is None we return the root logger
        return logger

    # If the logger is in a `nanotron` module then we remove its capability to propagate
    if LIBRARY_NAME == name.split(".", 1)[0]:
        if log_level is not None:
            logger.setLevel(log_level.upper())
        elif LEVEL is not None:
            logger.setLevel(LEVEL)
        else:
            logger.setLevel(_get_default_logging_level())
        if HANDLER is not None:
            logger.handlers.clear()
            logger.addHandler(HANDLER)
        logger.propagate = False

    return logger
def get_verbosity() -> int:
    """
    Return the current level for the Nanotron root logger as an int.

    Returns:
        :obj:`int`: The logging level.

    .. note::

        Nanotron has the following logging levels:

        - 50: ``nanotron.logging.CRITICAL`` or ``nanotron.logging.FATAL``
        - 40: ``nanotron.logging.ERROR``
        - 30: ``nanotron.logging.WARNING`` or ``nanotron.logging.WARN``
        - 20: ``nanotron.logging.INFO``
        - 10: ``nanotron.logging.DEBUG``
    """
    return get_library_root_logger().getEffectiveLevel()


LEVEL = None


def set_verbosity(verbosity: int) -> None:
    """
    Set the verbosity level for all `nanotron` loggers.

    Args:
        verbosity (:obj:`int`):
            Logging level, e.g., one of:

            - ``nanotron.logging.CRITICAL`` or ``nanotron.logging.FATAL``
            - ``nanotron.logging.ERROR``
            - ``nanotron.logging.WARNING`` or ``nanotron.logging.WARN``
            - ``nanotron.logging.INFO``
            - ``nanotron.logging.DEBUG``
    """
    all_nanotron_loggers = {
        name: logger
        for name, logger in logging.root.manager.loggerDict.items()
        if isinstance(logger, Logger) and (name.startswith(f"{LIBRARY_NAME}.") or name == LIBRARY_NAME)
    }
    for name, logger in all_nanotron_loggers.items():
        logger.setLevel(verbosity)
        # We update all handlers to be at the current verbosity as well.
        for handle in logger.handlers:
            handle.setLevel(verbosity)
    global LEVEL
    LEVEL = verbosity
HANDLER = None


def set_formatter(formatter: logging.Formatter) -> None:
    """
    Set a new custom formatter as the current handler.

    Note: it's important to first set the verbosity level and then the formatter,
    since the new handler takes its level from ``get_verbosity()``.

    :param formatter:
    :return:
    """
    handler = NewLineStreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    handler.setLevel(get_verbosity())
    handler.flush = sys.stderr.flush

    all_nanotron_loggers = {
        name: logger
        for name, logger in logging.root.manager.loggerDict.items()
        if isinstance(logger, Logger) and (name.startswith(f"{LIBRARY_NAME}.") or name == LIBRARY_NAME)
    }
    for name, logger in all_nanotron_loggers.items():
        # We keep only a single handler
        logger.handlers.clear()
        logger.addHandler(handler)
    global HANDLER
    HANDLER = handler
def log_rank(
    msg: str,
    logger: Logger,
    level: int,
    group: Optional[dist.ProcessGroup] = None,
    rank: Optional[int] = None,
    **kwargs,
):
    """Log only if the current process is the rank specified."""
    # Use the default group if group is not provided
    if group is None:
        group = torch_dist.distributed_c10d._get_default_group()

    # rank is None means everyone logs
    if rank is None or dist.get_rank(group) == rank:
        logger.log(level, msg, **kwargs)


@lru_cache(maxsize=None)
def warn_once(
    msg: str, logger: Logger, group: Optional[dist.ProcessGroup] = None, rank: Optional[int] = None, **kwargs
):
    log_rank(msg=msg, logger=logger, level=logging.WARNING, group=group, rank=rank, **kwargs)


def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str:
    if abs(num) < 1:
        return "{:.3g}".format(num)
    SIZES = ["", "K", "M", "G", "T", "P", "E"]
    num = float("{:.3g}".format(num))
    magnitude = 0
    i = 0
    while abs(num) >= 1000 and i < len(SIZES) - 1:
        magnitude += 1
        num /= 1000.0 if not divide_by_1024 else 1024.0
        i += 1
    return "{}{}".format("{:f}".format(num).rstrip("0").rstrip("."), SIZES[magnitude])
def log_memory(logger: logging.Logger):
    log_rank(
        f" Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
        f" Peak allocated {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB."
        f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
        logger=logger,
        level=logging.INFO,
        rank=0,
    )
    torch.cuda.reset_peak_memory_stats()


@dataclass
class LogItem:
    tag: str
    scalar_value: Union[float, int, str]
    log_format: Optional[str] = None


@dataclass
class LoggerWriter:
    global_step: int

    def add_scalar(self, tag: str, scalar_value: Union[float, int], log_format=None) -> str:
        if log_format == "human_format":
            log_str = f"{tag}: {human_format(scalar_value)}"
        else:
            log_str = f"{tag}: {scalar_value:{log_format}}" if log_format is not None else f"{tag}: {scalar_value}"
        return log_str

    def add_scalars_from_list(self, log_entries: List[LogItem], iteration_step: int):
        log_strs = [f"iteration: {iteration_step} / {self.global_step}"]
        log_strs += [
            self.add_scalar(log_item.tag, log_item.scalar_value, log_item.log_format) for log_item in log_entries
        ]
        log_str = " | ".join(log_strs)
        log_rank(log_str, logger=get_logger(__name__), level=logging.INFO)
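

# Hedged usage sketch (not part of the original diff): formatting scalars with
# LoggerWriter.add_scalar. Note that add_scalars_from_list goes through log_rank,
# which expects an initialized torch.distributed process group, so only the pure
# string formatting is exercised here; the tag names are illustrative.
def _example_logger_writer():  # pragma: no cover
    writer = LoggerWriter(global_step=1000)
    items = [
        LogItem(tag="lm_loss", scalar_value=2.345678, log_format=".4f"),
        LogItem(tag="tokens_per_sec", scalar_value=1_250_000, log_format="human_format"),
    ]
    # -> ["lm_loss: 2.3457", "tokens_per_sec: 1.25M"]
    return [writer.add_scalar(item.tag, item.scalar_value, item.log_format) for item in items]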
def set_logger_verbosity_format(logging_level: str, parallel_context: "ParallelContext"):
    node_name = os.environ.get("SLURMD_NODENAME")
    expert_parallel_log = (
        f"|EXP={dist.get_rank(parallel_context.expert_pg)}" if parallel_context.expert_parallel_size > 1 else ""
    )
    formatter = Formatter(
        fmt=f"%(asctime)s [%(levelname)s|DP={dist.get_rank(parallel_context.dp_pg)}|PP={dist.get_rank(parallel_context.pp_pg)}|"
        f"TP={dist.get_rank(parallel_context.tp_pg)}{expert_parallel_log}{'|' + node_name if node_name else ''}]: %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
    )
    log_level = log_levels[logging_level]

    # main root logger
    root_logger = get_logger()
    root_logger.setLevel(log_level)
    handler = NewLineStreamHandler(sys.stdout)
    handler.setLevel(log_level)
    handler.setFormatter(formatter)
    root_logger.addHandler(handler)

    # Nanotron
    set_verbosity(log_level)
    set_formatter(formatter=formatter)


def set_ranks_logging_level(parallel_context: "ParallelContext", logging_config: "LoggingArgs"):
    if dist.get_rank(parallel_context.world_pg) == 0:
        if logging_config.log_level is not None:
            set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context)
    else:
        if logging_config.log_level_replica is not None:
            set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context)


_configure_library_root_logger()
# flake8: noqa
from .base import DTypeInvariantTensor, NanotronModel, build_model, check_model_has_grad, init_on_device_and_dtype