Commit dfcb88ff authored by chenzk

v1.0.8
import os
import pytest
import torch
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from torch import nn as torch_nn
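# These tests compare each sharded tensor-parallel module (column-parallel linear, row-parallel
# linear, embedding) against an un-sharded torch.nn reference, checking both the forward outputs
# and the backward gradients for every TensorParallelLinearMode.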
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_column_linear(
tp: int,
dp: int,
pp: int,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")
init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(
tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
)
def _test_column_linear(
parallel_context: ParallelContext,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if async_communication:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
in_features = 2
out_features_per_tp_rank = 3
out_features = parallel_context.tp_pg.size() * out_features_per_tp_rank
# Sharded
column_linear = TensorParallelColumnLinear(
in_features=in_features,
out_features=out_features,
pg=parallel_context.tp_pg,
mode=tp_mode,
device="cuda",
async_communication=async_communication,
tp_recompute_allgather=tp_recompute_allgather,
)
# Un-sharded
reference_linear = torch_nn.Linear(in_features=in_features, out_features=out_features, device="cuda")
# Copy weights/bias from sharded to un-sharded
with torch.inference_mode():
dist.all_gather(
tensor_list=list(reference_linear.weight.split(out_features_per_tp_rank, dim=0)),
tensor=column_linear.weight,
group=parallel_context.tp_pg,
)
dist.all_gather(
tensor_list=list(reference_linear.bias.split(out_features_per_tp_rank, dim=0)),
tensor=column_linear.bias,
group=parallel_context.tp_pg,
)
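    # After the all-gathers, `reference_linear` holds the full (un-sharded) weight/bias, i.e. the
    # concatenation of every TP rank's shard along the output-feature dimension.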
# Generate random input
random_input: torch.Tensor
sharded_random_input: torch.Tensor
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
batch_size = 5
random_input = torch.randn(batch_size, in_features, device="cuda")
# synchronize random_input across tp
dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg)
sharded_random_input = random_input
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
sharded_batch_size = 5
sharded_random_input = torch.randn(sharded_batch_size, in_features, device="cuda")
if parallel_context.tp_pg.size() > 1:
random_input = torch.empty(
sharded_batch_size * parallel_context.tp_pg.size(),
*(sharded_random_input.shape[1:]),
device=sharded_random_input.device,
dtype=sharded_random_input.dtype,
)
dist.all_gather_into_tensor(random_input, sharded_random_input, group=parallel_context.tp_pg)
else:
random_input = sharded_random_input
else:
        raise ValueError(f"Unsupported mode: {tp_mode}")
# It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
sharded_random_input = sharded_random_input.clone()
random_input.requires_grad = True
sharded_random_input.requires_grad = True
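    # Optional sanity check: after the clone above the sharded and full inputs no longer share
    # storage, so the two backward passes below accumulate into independent `.grad` buffers.
    assert sharded_random_input.data_ptr() != random_input.data_ptr()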
# Test that we get the same output after forward pass
sharded_output = column_linear(sharded_random_input)
reference_output = reference_linear(random_input)
# TODO @thomasw21: Tune tolerance
try:
torch.testing.assert_close(
sharded_output,
reference_output[
:,
dist.get_rank(parallel_context.tp_pg)
* out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* out_features_per_tp_rank,
],
)
except BaseException as e:
print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: FAIL.")
dist.barrier()
raise e
print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: SUCCESS.")
dist.barrier()
# Test that we get the same gradient after backward pass
sharded_output.sum().backward()
reference_output.sum().backward()
hidden_dim_slice = slice(
dist.get_rank(parallel_context.tp_pg) * out_features_per_tp_rank,
(dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank,
)
torch.testing.assert_close(
column_linear.weight.grad,
reference_linear.weight.grad[hidden_dim_slice],
)
torch.testing.assert_close(
column_linear.bias.grad,
reference_linear.bias.grad[hidden_dim_slice],
)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
torch.testing.assert_close(
sharded_random_input.grad,
random_input.grad,
)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
batch_dim_slice = slice(
dist.get_rank(parallel_context.tp_pg) * sharded_batch_size,
(dist.get_rank(parallel_context.tp_pg) + 1) * sharded_batch_size,
)
torch.testing.assert_close(
sharded_random_input.grad,
random_input.grad[batch_dim_slice],
)
else:
        raise ValueError(f"Unsupported mode: {tp_mode}")
parallel_context.destroy()
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_row_linear(
tp: int,
dp: int,
pp: int,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather")
init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(
tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
)
def _test_row_linear(
parallel_context: ParallelContext,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if async_communication:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
out_features = 3
in_features_per_rank = 2
in_features = parallel_context.tp_pg.size() * in_features_per_rank
# Sharded
row_linear = TensorParallelRowLinear(
in_features=in_features,
out_features=out_features,
pg=parallel_context.tp_pg,
mode=tp_mode,
device="cuda",
async_communication=async_communication,
)
# Un-sharded
reference_linear = torch_nn.Linear(in_features=in_features, out_features=out_features, device="cuda")
    # Sync the un-sharded reference weights across TP, then copy this rank's shard into the sharded layer
with torch.inference_mode():
dist.all_reduce(tensor=reference_linear.weight, op=dist.ReduceOp.SUM, group=parallel_context.tp_pg)
row_linear.weight.copy_(
reference_linear.weight[
:,
dist.get_rank(parallel_context.tp_pg)
* in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* in_features_per_rank,
]
)
        # Broadcast the bias from rank 0; the other TP ranks do not hold a bias
if dist.get_rank(parallel_context.tp_pg) == 0:
row_linear.bias.copy_(reference_linear.bias)
dist.broadcast(
tensor=reference_linear.bias,
src=get_global_rank(group=parallel_context.tp_pg, group_rank=0),
group=parallel_context.tp_pg,
)
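    # Only TP rank 0 owns the bias in the row-parallel layer: the bias is added once after the
    # partial outputs are reduced across ranks, so every other rank keeps `bias=None`
    # (checked further below).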
# Generate random input
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
batch_size = 5
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
batch_size = 5 * parallel_context.tp_pg.size()
else:
        raise ValueError(f"Unsupported mode: {tp_mode}")
random_input = torch.randn(batch_size, in_features, device="cuda")
# synchronize random_input across tp
dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg)
random_input.requires_grad = True
    # The row-parallel linear receives a sharded input
random_sharded_input = (
random_input[
:,
dist.get_rank(parallel_context.tp_pg)
* in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* in_features_per_rank,
]
.detach()
.clone()
)
random_sharded_input.requires_grad = True
# Test that we get the same output after forward pass
# TODO @kunhao: We may want to have our custom error type
sharded_output = row_linear(random_sharded_input)
reference_output = reference_linear(random_input)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
sharded_reference_output = reference_output
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
assert batch_size % parallel_context.tp_pg.size() == 0
sharded_batch_size = batch_size // parallel_context.tp_pg.size()
sharded_reference_output = reference_output[
dist.get_rank(parallel_context.tp_pg)
* sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1)
* sharded_batch_size
]
else:
raise ValueError(f"Unsupported mode: {tp_mode}")
# TODO @thomasw21: Tune tolerance
torch.testing.assert_close(
sharded_output,
sharded_reference_output,
)
# Test that we get the same gradient after backward pass
sharded_output.sum().backward()
reference_output.sum().backward()
torch.testing.assert_close(
row_linear.weight.grad,
reference_linear.weight.grad[
:,
dist.get_rank(parallel_context.tp_pg)
* in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* in_features_per_rank,
],
)
if dist.get_rank(parallel_context.tp_pg) == 0:
torch.testing.assert_close(
row_linear.bias.grad,
reference_linear.bias.grad,
)
else:
assert row_linear.bias is None
torch.testing.assert_close(
random_sharded_input.grad,
random_input.grad[
:,
dist.get_rank(parallel_context.tp_pg)
* in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* in_features_per_rank,
],
)
parallel_context.destroy()
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@rerun_if_address_is_in_use()
def test_tensor_parallel_embedding(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode):
init_distributed(tp=tp, dp=dp, pp=pp)(_test_tensor_parallel_embedding)(tp_mode=tp_mode)
def _test_tensor_parallel_embedding(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode):
num_embeddings_per_rank = 100
embedding_dim = 3
num_embeddings = parallel_context.tp_pg.size() * num_embeddings_per_rank
# Sharded
sharded_embedding = TensorParallelEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
pg=parallel_context.tp_pg,
mode=tp_mode,
device="cuda",
)
# Un-sharded
reference_embedding = torch_nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, device="cuda")
    # Sync the un-sharded reference weights across TP, then copy this rank's shard into the sharded embedding
with torch.inference_mode():
dist.all_reduce(tensor=reference_embedding.weight, op=dist.ReduceOp.SUM, group=parallel_context.tp_pg)
sharded_embedding.weight.copy_(
reference_embedding.weight[
dist.get_rank(parallel_context.tp_pg)
* num_embeddings_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* num_embeddings_per_rank,
:,
]
)
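    # The embedding table is sharded along the vocabulary dimension: each TP rank owns
    # `num_embeddings_per_rank` consecutive rows of the full table.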
# Generate random input
random_input: torch.Tensor
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
batch_size = 5
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
batch_size = 5 * parallel_context.tp_pg.size()
else:
raise ValueError(f"Unsupported mode: {tp_mode}")
random_input = torch.randint(low=0, high=num_embeddings, size=(batch_size,), device="cuda")
dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg)
# Test that we get the same output after forward pass
sharded_output = sharded_embedding(random_input)
reference_output = reference_embedding(random_input)
weights = torch.arange(batch_size, device="cuda")[:, None]
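    # Per-position weights make each batch row contribute differently to the scalar loss, so a
    # wrong batch sharding in REDUCE_SCATTER mode would show up in the gradient comparison below.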
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
sharded_reference_output = reference_output
sharded_weights = weights
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
assert batch_size % parallel_context.tp_pg.size() == 0
sharded_batch_size = batch_size // parallel_context.tp_pg.size()
sharded_reference_output = reference_output[
dist.get_rank(parallel_context.tp_pg)
* sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1)
* sharded_batch_size
]
sharded_weights = weights[
dist.get_rank(parallel_context.tp_pg)
* sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1)
* sharded_batch_size
]
else:
raise ValueError(f"Unsupported mode: {tp_mode}")
# TODO @thomasw21: Tune tolerance
torch.testing.assert_close(sharded_output, sharded_reference_output, atol=0, rtol=0)
# Test that we get the same gradient after backward pass
(sharded_output * sharded_weights).sum().backward()
(reference_output * weights).sum().backward()
torch.testing.assert_close(
sharded_embedding.weight.grad,
reference_embedding.weight.grad[
dist.get_rank(parallel_context.tp_pg)
* num_embeddings_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
* num_embeddings_per_rank,
:,
],
atol=0,
rtol=0,
)
parallel_context.destroy()
import torch
from helpers.distributed_tensor import assert_tensor_equal_over_group
from helpers.exception import assert_fail_with
from helpers.utils import init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.tied_parameters import (
get_tied_id_to_param,
sync_tied_weights_gradients,
tie_parameters,
)
from torch import nn
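# These tests cover tied parameters: tying on the same device must reuse a single parameter
# object, while tying across pipeline ranks marks each copy as tied and relies on explicit
# synchronization (all-reduce over the tied group) for both weights and gradients.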
@rerun_if_address_is_in_use()
def test_tie_weight_in_same_device():
init_distributed(tp=1, dp=1, pp=1)(_test_tie_weight_in_same_device)()
def _test_tie_weight_in_same_device(parallel_context: ParallelContext):
model = nn.ModuleDict({"dense0": nn.Linear(10, 10, device="cuda"), "dense1": nn.Linear(10, 10, device="cuda")})
# Tie weights/bias
tie_parameters(
root_module=model,
ties=[("dense0.weight", (0,)), ("dense1.weight", (0,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
tie_parameters(
root_module=model,
ties=[("dense0.bias", (0,)), ("dense1.bias", (0,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
weight0 = model.get_parameter("dense0.weight")
weight1 = model.get_parameter("dense1.weight")
bias0 = model.get_parameter("dense0.bias")
bias1 = model.get_parameter("dense1.bias")
# We check that we use the same parameter for both linear layers
assert id(weight0) == id(weight1)
assert id(bias0) == id(bias1)
parallel_context.destroy()
@rerun_if_address_is_in_use()
def test_tie_weight_in_different_device():
init_distributed(tp=1, dp=1, pp=2)(_test_tie_weight_in_different_device)()
def _test_tie_weight_in_different_device(parallel_context: ParallelContext):
if dist.get_rank(parallel_context.pp_pg) == 0:
model = nn.ModuleDict(
{
"dense0": nn.Linear(10, 10, device="cuda"),
}
)
else:
model = nn.ModuleDict(
{
"dense1": nn.Linear(10, 10, device="cuda"),
}
)
# Tie weights/bias
tie_parameters(
root_module=model,
ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
tie_parameters(
root_module=model,
ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
group = parallel_context.world_ranks_to_pg[(0, 1)]
# Check that model weights are not in fact synchronized
if dist.get_rank(parallel_context.pp_pg) == 0:
weight = model.dense0.weight
bias = model.dense0.bias
else:
weight = model.dense1.weight
bias = model.dense1.bias
# Make sure that weight/bias are NanotronParameter and that they are tied
assert isinstance(weight, NanotronParameter)
assert weight.is_tied
assert isinstance(bias, NanotronParameter)
assert bias.is_tied
# Weights/bias are not synced yet
assert not assert_tensor_equal_over_group(weight, group=group, assert_=False)
assert not assert_tensor_equal_over_group(bias, group=group, assert_=False)
# Manually sync weights
for (_, group_ranks), param in sorted(
get_tied_id_to_param(
parameters=model.parameters(),
root_module=model,
).items(),
key=lambda x: x[0],
):
group = parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
    # Check that the weights/bias now match across the tied group
assert_tensor_equal_over_group(weight, group=group)
assert_tensor_equal_over_group(bias, group=group)
parallel_context.destroy()
@rerun_if_address_is_in_use()
def test_tie_weight_across_dp_is_impossible():
init_distributed(tp=1, dp=2, pp=1)(_test_tie_weight_across_dp_is_impossible)()
def _test_tie_weight_across_dp_is_impossible(parallel_context: ParallelContext):
if dist.get_rank(parallel_context.dp_pg) == 0:
model = nn.ModuleDict(
{
"dense0": nn.Linear(10, 10, device="cuda"),
}
)
else:
model = nn.ModuleDict(
{
"dense1": nn.Linear(10, 10, device="cuda"),
}
)
# Tie weights/bias
with assert_fail_with(AssertionError):
tie_parameters(
root_module=model,
ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
with assert_fail_with(AssertionError):
tie_parameters(
root_module=model,
ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
parallel_context.destroy()
@rerun_if_address_is_in_use()
def test_tie_weight_in_different_device_have_gradients_synchronized():
init_distributed(tp=1, dp=1, pp=2)(_test_tie_weight_in_different_device_have_gradients_synchronized)()
def _test_tie_weight_in_different_device_have_gradients_synchronized(parallel_context: ParallelContext):
if dist.get_rank(parallel_context.pp_pg) == 0:
model = nn.ModuleDict(
{
"dense0": nn.Linear(10, 10, device="cuda"),
}
)
else:
model = nn.ModuleDict(
{
"dense1": nn.Linear(10, 10, device="cuda"),
}
)
# Tie weights/bias
tie_parameters(
root_module=model,
ties=[("dense0.weight", (0,)), ("dense1.weight", (1,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
tie_parameters(
root_module=model,
ties=[("dense0.bias", (0,)), ("dense1.bias", (1,))],
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
group = parallel_context.world_ranks_to_pg[(0, 1)]
# Check that model weights are not in fact synchronized
if dist.get_rank(parallel_context.pp_pg) == 0:
weight = model.dense0.weight
bias = model.dense0.bias
else:
weight = model.dense1.weight
bias = model.dense1.bias
# Make sure that weight/bias are NanotronParameter and that they are tied
assert isinstance(weight, NanotronParameter)
assert weight.is_tied
assert isinstance(bias, NanotronParameter)
assert bias.is_tied
# Weights/bias are not synced yet
assert not assert_tensor_equal_over_group(weight, group=group, assert_=False)
assert not assert_tensor_equal_over_group(bias, group=group, assert_=False)
# Compute gradient
input_ = torch.randn(13, 10, device="cuda")
if dist.get_rank(parallel_context.pp_pg) == 0:
out = model.dense0(input_)
else:
out = model.dense1(input_)
out.sum().backward()
# sync gradients
# TODO @thomasw21: This should be done in hooks
sync_tied_weights_gradients(model, parallel_context=parallel_context, grad_accumulator=None)
    # Check that we have gradients
assert weight.grad is not None
assert bias.grad is not None
    # Check that both gradients are synchronized across the tied group
assert_tensor_equal_over_group(weight.grad, group=group)
assert_tensor_equal_over_group(bias.grad, group=group)
parallel_context.destroy()
import os
import pytest
import torch
from helpers.distributed_tensor import assert_tensor_equal_over_group
from helpers.dummy import dummy_infinite_data_loader, init_dummy_model
from helpers.exception import assert_fail_with
from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use
from nanotron import distributed as dist
from nanotron.optim import NamedOptimizer, ZeroDistributedOptimizer
from nanotron.optim.zero import SlicedFlatTensor
from nanotron.parallel import ParallelContext
from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel import nn
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tied_parameters import sync_tied_weights_gradients
from nanotron.random import RandomStates, branch_random_state, get_current_random_state, get_synced_random_state
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel
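# These tests exercise ZeroDistributedOptimizer: optimizer states are sharded across DP ranks,
# each rank's sliced parameters/gradients must be views into the full tensors, and training
# must stay in lockstep with a plain (non-sharded) AdamW reference.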
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(1, i, 1) for i in range(1, min(4, available_gpus()) + 1)])
@rerun_if_address_is_in_use()
def test_zero_optimizer(tp: int, dp: int, pp: int):
init_distributed(pp=pp, dp=dp, tp=tp)(_test_zero_optimizer)()
def _test_zero_optimizer(parallel_context: ParallelContext):
model = init_dummy_model(parallel_context=parallel_context)
optimizer = ZeroDistributedOptimizer(
named_params_or_groups=model.named_parameters(),
optimizer_builder=lambda named_param_groups: NamedOptimizer(
named_params_or_groups=named_param_groups,
optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups),
),
dp_pg=parallel_context.dp_pg,
)
index_to_name = [name for name, _ in model.named_parameters()]
# reference model
reference_model = init_dummy_model(parallel_context=parallel_context)
reference_optimizer = torch.optim.AdamW(reference_model.parameters())
# sync weights between reference_model and model
with torch.no_grad():
for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()):
assert name == ref_name
param.copy_(ref_param)
# Get infinite dummy data iterator
data_loader = iter(dummy_infinite_data_loader(pp_pg=parallel_context.pp_pg))
nb_optim_steps = 3
batches = [[next(data_loader)] for _ in range(nb_optim_steps)]
pipeline_engine = AllForwardAllBackwardPipelineEngine()
# Training loop
for i, batch in enumerate(batches):
        # Store the original parameters so we can later check that they were updated
old_named_params = {name: param.detach().clone() for name, param in model.named_parameters()}
# Run forward/backward
losses = pipeline_engine.train_batch_iter(
model=model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=1, grad_accumulator=None
)
ref_losses = pipeline_engine.train_batch_iter(
model=reference_model, pg=parallel_context.pp_pg, batch=batch, nb_microbatches=1, grad_accumulator=None
)
# Check loss match
losses = list(losses)
ref_losses = list(ref_losses)
assert len(losses) == len(ref_losses)
for loss, ref_loss in zip(losses, ref_losses):
assert isinstance(loss["loss"], torch.Tensor)
assert isinstance(ref_loss["loss"], torch.Tensor)
torch.testing.assert_close(
loss["loss"], ref_loss["loss"], atol=0, rtol=0, msg=lambda msg: f"At iteration {i}, {msg}"
)
# Manually sync tied parameters' gradients
sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None)
sync_tied_weights_gradients(module=reference_model, parallel_context=parallel_context, grad_accumulator=None)
# We rely on DDP to synchronize gradients across DP. We only need to manually synchronize them if we don't use DDP.
if not isinstance(model, DistributedDataParallel):
sync_gradients_across_dp(
model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None
)
if not isinstance(reference_model, DistributedDataParallel):
sync_gradients_across_dp(
reference_model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None
)
# Check gradients are synced across DP
for name, param in model.named_parameters():
assert_tensor_equal_over_group(param.grad, group=parallel_context.dp_pg)
for ref_name, ref_param in reference_model.named_parameters():
assert_tensor_equal_over_group(ref_param.grad, group=parallel_context.dp_pg)
# Check gradients are the same with reference_model
for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()):
assert name == ref_name
torch.testing.assert_close(
param.grad, ref_param.grad, atol=0, rtol=0, msg=lambda msg: f"At iteration {i}, {msg}"
)
assert len(optimizer.param_groups) == 1
assert len(list(model.named_parameters())) == len(optimizer.param_groups[0]["params"])
with torch.no_grad():
for (name, param), sliced_param in zip(model.named_parameters(), optimizer.param_groups[0]["params"]):
offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)]
# Check that weights are the same
expected_slice = param.view(-1)[slice(*offsets)].view_as(sliced_param)
torch.testing.assert_close(
expected_slice,
sliced_param,
atol=0,
rtol=0,
msg=lambda msg: f"Weights don't match: {msg}\n - Expected slice: {expected_slice}\n - Got: {sliced_param}\n - Full gradient: {param}",
)
assert (
expected_slice.data_ptr() == sliced_param.data_ptr()
), "Parameters should actually share the same data pointer"
                # Check that the gradient slice is a view into the full gradient
expected_slice = param.grad.view(-1)[slice(*offsets)].view_as(sliced_param.grad)
assert (
expected_slice.data_ptr() == sliced_param.grad.data_ptr()
), "Parameters should actually share the same data pointer"
torch.testing.assert_close(
expected_slice,
sliced_param.grad,
atol=0,
rtol=0,
msg=lambda msg: f"Gradients don't match: {msg}\n - Expected slice: {expected_slice}\n - Got: {sliced_param.grad}\n - Full gradient: {param.grad}",
)
# Optimizer steps
optimizer.step()
optimizer.zero_grad()
reference_optimizer.step()
reference_optimizer.zero_grad()
# Check that params are synced across DP
for name, param in model.named_parameters():
assert_tensor_equal_over_group(param, group=parallel_context.dp_pg)
assert param.grad is None
# Check that gradients are reset
for ref_name, ref_param in reference_model.named_parameters():
assert_tensor_equal_over_group(ref_param, group=parallel_context.dp_pg)
assert ref_param.grad is None
for param_group in optimizer.param_groups:
for param in param_group["params"]:
assert param.grad is None
# Check params are the same with reference_model
for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()):
assert name == ref_name
# TODO @thomasw21: Figure out how to make this pass at `atol`/`rtol` set to 0.
torch.testing.assert_close(param, ref_param, msg=lambda msg: f"At iteration {i}, {msg}")
# Check params have been updated correctly
for (name, param) in model.named_parameters():
old_param = old_named_params[name]
assert not torch.allclose(param, old_param)
# We need to check that the optimizer states are the same
state_dict = optimizer.state_dict()
reference_state_dict = reference_optimizer.state_dict()
state = state_dict["state"]
ref_state = reference_state_dict["state"]
assert set(state) == set(ref_state)
for index, optim_state in state.items():
ref_optim_state = ref_state[index]
name = index_to_name[index]
offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)]
assert set(optim_state) == set(ref_optim_state)
for key in ["exp_avg", "exp_avg_sq"]:
value = optim_state[key]
ref_value = ref_optim_state[key]
torch.testing.assert_close(
value,
ref_value.view(-1)[slice(*offsets)].view_as(value),
atol=0,
rtol=0,
msg=lambda msg: f"At iteration {i}, {msg}",
)
parallel_context.destroy()
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(2, i, 1) for i in range(1, available_gpus() // 2 + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@rerun_if_address_is_in_use()
def test_zero_optimizer_with_tp(
tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
init_distributed(pp=pp, dp=dp, tp=tp)(_test_zero_optimizer_with_tp)(
tp_mode=tp_mode, async_communication=async_communication
)
def _test_zero_optimizer_with_tp(
parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool
):
if async_communication:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
model = torch_nn.Sequential(
nn.TensorParallelColumnLinear(
in_features=5,
out_features=parallel_context.tp_pg.size(),
mode=tp_mode,
pg=parallel_context.tp_pg,
device="cuda",
async_communication=async_communication,
),
# We choose `sigmoid` instead of `relu` since `relu` can result in a sparse gradient, causing no update to certain parameters
torch_nn.Sigmoid(),
nn.TensorParallelRowLinear(
in_features=parallel_context.tp_pg.size(),
out_features=3,
mode=tp_mode,
pg=parallel_context.tp_pg,
device="cuda",
),
)
optimizer = ZeroDistributedOptimizer(
named_params_or_groups=model.named_parameters(),
optimizer_builder=lambda named_param_groups: NamedOptimizer(
named_params_or_groups=named_param_groups,
optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups),
),
dp_pg=parallel_context.dp_pg,
)
optimizer_name_to_id = {v: k for k, v in optimizer.optimizer.id_to_name.items()}
assert len(optimizer_name_to_id) == len(optimizer.id_to_name)
# reference model
reference_model = torch_nn.Sequential(
torch_nn.Linear(in_features=5, out_features=parallel_context.tp_pg.size(), device="cuda"),
torch_nn.Sigmoid(),
torch_nn.Linear(in_features=parallel_context.tp_pg.size(), out_features=3, device="cuda"),
)
for module in reference_model.modules():
for name, param in module.named_parameters(recurse=False):
setattr(module, name, NanotronParameter(param))
reference_optimizer = torch.optim.AdamW(reference_model.parameters())
    # TODO @thomasw21: This is a hack to obtain the `AdamW` index in its state.
name_to_index = {name: index for index, (name, _) in enumerate(reference_model.named_parameters())}
# sync parameters
with torch.no_grad():
for ref_name, ref_param in reference_model.named_parameters():
dist.all_reduce(ref_param, op=dist.ReduceOp.AVG, group=parallel_context.world_pg)
for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()):
assert name == ref_name
assert isinstance(param, NanotronParameter)
if param.is_sharded:
sharded_info = param.get_sharded_info()
for local_global_slices_pair in sharded_info.local_global_slices_pairs:
local_slices = local_global_slices_pair.local_slices
global_slices = local_global_slices_pair.global_slices
param[local_slices].copy_(ref_param[global_slices])
else:
param.copy_(ref_param)
    # Generate dummy batches; the randomness has to be synced across TP
random_states = RandomStates(
{
"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg),
}
)
batch_size = 2 * parallel_context.tp_pg.size() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER else 7
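    # In REDUCE_SCATTER mode the batch dimension is scattered across TP ranks, so the global
    # batch size has to be a multiple of the TP world size; ALL_REDUCE keeps the full batch on every rank.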
with branch_random_state(random_states=random_states, key="tp_synced", enabled=True):
nb_optim_steps = 3
batches = [
torch.randn(batch_size, 5, device="cuda")
if dist.get_rank(parallel_context.pp_pg) == 0
else TensorPointer(0)
for _ in range(nb_optim_steps)
]
# Model training loop
for i, batch in enumerate(batches):
        # Store the original parameters so we can later check that they were updated
old_named_params = {name: param.detach().clone() for name, param in model.named_parameters()}
# Run forward pass
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
batch_size = batch.shape[0]
assert batch_size % parallel_context.tp_pg.size() == 0
step = batch_size // parallel_context.tp_pg.size()
loss = model(
batch[
dist.get_rank(parallel_context.tp_pg) * step : (dist.get_rank(parallel_context.tp_pg) + 1) * step
]
)
else:
loss = model(batch)
ref_loss = reference_model(batch)
# Run backward pass
loss.sum().backward()
ref_loss.sum().backward()
# Check loss is the same
loss = loss.detach()
ref_loss = ref_loss.detach()
assert isinstance(loss, torch.Tensor)
assert isinstance(ref_loss, torch.Tensor)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
batch_size = batch.shape[0]
assert batch_size % parallel_context.tp_pg.size() == 0
step = batch_size // parallel_context.tp_pg.size()
torch.testing.assert_close(
loss,
ref_loss[
dist.get_rank(parallel_context.tp_pg) * step : (dist.get_rank(parallel_context.tp_pg) + 1) * step
],
msg=lambda msg: f"At iteration {i}, {msg}",
)
else:
torch.testing.assert_close(loss, ref_loss, msg=lambda msg: f"At iteration {i}, {msg}")
        # Manually sync tied parameters' gradients
sync_tied_weights_gradients(module=model, parallel_context=parallel_context, grad_accumulator=None)
sync_tied_weights_gradients(module=reference_model, parallel_context=parallel_context, grad_accumulator=None)
# We rely on DDP to synchronize gradients across DP. We only need to manually synchronize them if we don't use DDP.
if not isinstance(model, DistributedDataParallel):
sync_gradients_across_dp(
model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None
)
if not isinstance(reference_model, DistributedDataParallel):
sync_gradients_across_dp(
reference_model, dp_pg=parallel_context.dp_pg, reduce_op=dist.ReduceOp.AVG, grad_accumulator=None
)
# Check gradients are synced across DP
for name, param in model.named_parameters():
assert_tensor_equal_over_group(param.grad, group=parallel_context.dp_pg)
for ref_name, ref_param in reference_model.named_parameters():
assert_tensor_equal_over_group(ref_param.grad, group=parallel_context.dp_pg)
# Check gradients are the same with reference_model
for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()):
assert name == ref_name
if param.is_sharded:
sharded_info = param.get_sharded_info()
for local_global_slices_pair in sharded_info.local_global_slices_pairs:
local_slices = local_global_slices_pair.local_slices
global_slices = local_global_slices_pair.global_slices
torch.testing.assert_close(
param.grad[local_slices],
ref_param.grad[global_slices],
msg=lambda msg: f"At iteration {i}, {msg}",
)
else:
torch.testing.assert_close(param.grad, ref_param.grad, msg=lambda msg: f"At iteration {i}, {msg}")
with torch.no_grad():
optim_param_id_to_param = {id(param): param for param in optimizer.param_groups[0]["params"]}
assert len(optim_param_id_to_param) == len(optimizer.param_groups[0]["params"])
for name, param in model.named_parameters():
if dist.get_rank(parallel_context.dp_pg) not in optimizer.param_name_to_dp_rank_offsets[name]:
assert name not in optimizer_name_to_id
continue
param_id = optimizer_name_to_id[name]
sliced_param = optim_param_id_to_param[param_id]
offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)]
# Check that weights share the same storage
expected_slice = param.view(-1)[slice(*offsets)].view_as(sliced_param)
torch.testing.assert_close(
expected_slice,
sliced_param,
atol=0,
rtol=0,
msg=lambda msg: f"At iteration {i}, weights don't match: {msg}\n - Expected slice: {expected_slice}\n - Got: {sliced_param}\n - Full gradient: {param}",
)
assert (
expected_slice.data_ptr() == sliced_param.data_ptr()
), "Parameters should actually share the same data pointer"
# Check that gradients share the same storage
expected_slice = param.grad.view(-1)[slice(*offsets)].view_as(sliced_param.grad)
assert (
expected_slice.data_ptr() == sliced_param.grad.data_ptr()
), "Parameters should actually share the same data pointer"
torch.testing.assert_close(
expected_slice,
sliced_param.grad,
atol=0,
rtol=0,
msg=lambda msg: f"At iteration {i}, gradients don't match: {msg}\n - Expected slice: {expected_slice}\n - Got: {sliced_param.grad}\n - Full gradient: {param.grad}",
)
# Optimizer steps
optimizer.step()
optimizer.zero_grad()
reference_optimizer.step()
reference_optimizer.zero_grad()
# Check that params are synced across DP
for name, param in model.named_parameters():
assert_tensor_equal_over_group(param, group=parallel_context.dp_pg)
assert param.grad is None
# Check that gradients are reset
for ref_name, ref_param in reference_model.named_parameters():
assert_tensor_equal_over_group(ref_param, group=parallel_context.dp_pg)
assert ref_param.grad is None
for param_group in optimizer.param_groups:
for param in param_group["params"]:
assert param.grad is None
# Check params are the same with reference_model
for (name, param), (ref_name, ref_param) in zip(model.named_parameters(), reference_model.named_parameters()):
assert name == ref_name
if param.is_sharded:
sharded_info = param.get_sharded_info()
for local_global_slices_pair in sharded_info.local_global_slices_pairs:
local_slices = local_global_slices_pair.local_slices
global_slices = local_global_slices_pair.global_slices
torch.testing.assert_close(
param[local_slices], ref_param[global_slices], msg=lambda msg: f"At iteration {i}, {msg}"
)
else:
torch.testing.assert_close(param, ref_param, msg=lambda msg: f"At iteration {i}, {msg}")
# Check params have been updated correctly:
for (name, param) in model.named_parameters():
old_param = old_named_params[name]
assert not torch.allclose(param, old_param)
# We need to check that the optimizer states are the same
state_dict = optimizer.state_dict()
reference_state_dict = reference_optimizer.state_dict()
state = state_dict["state"]
ref_state = reference_state_dict["state"]
assert "names" in state_dict
state_index_to_name = state_dict["names"]
state_name_to_index = {name: index for index, name in state_index_to_name.items()}
# Check that this is a bijection
assert len(state_index_to_name) == len(state_name_to_index)
for name, param in model.named_parameters():
if name not in state_name_to_index:
                # Parameter is not passed to the optimizer on this rank, due to the ZeRO sharding strategy
continue
index = state_name_to_index[name]
optim_state = state[index]
ref_optim_state = ref_state[name_to_index[name]]
offsets = optimizer.param_name_to_dp_rank_offsets[name][dist.get_rank(parallel_context.dp_pg)]
assert set(optim_state) == set(ref_optim_state)
assert isinstance(param, NanotronParameter)
for key in ["exp_avg", "exp_avg_sq"]:
value = optim_state[key]
ref_value = ref_optim_state[key]
if param.is_sharded:
sharded_info = param.get_sharded_info()
for local_global_slices_pair in sharded_info.local_global_slices_pairs:
global_slices = local_global_slices_pair.global_slices
torch.testing.assert_close(
# TODO @thomasw21: We can't add any information about `local_slices` to `value` because it's already flattened
# For now, we're going to assume that sharded parameters are contiguous, and `local_slices` are trivial all none slice
value,
ref_value[global_slices].view(-1)[slice(*offsets)],
msg=lambda msg: f"At iteration {i}, {msg}",
)
else:
torch.testing.assert_close(
value,
ref_value.view(-1)[slice(*offsets)].view_as(value),
msg=lambda msg: f"At iteration {i}, {msg}",
)
parallel_context.destroy()
@rerun_if_address_is_in_use()
def test_sliced_flat_tensor():
init_distributed(1, 1, 1)(_test_sliced_flat_tensor)()
def _test_sliced_flat_tensor(parallel_context: ParallelContext):
a = torch.randn(2, 3, requires_grad=True)
grad = torch.randn(2, 3)
a.grad = grad
start_offset, end_offset = 1, 5
b = SlicedFlatTensor(a, start_offset=start_offset, end_offset=end_offset)
torch.testing.assert_close(a.grad, grad, atol=0, rtol=0)
torch.testing.assert_close(b.grad, grad.view(-1)[start_offset:end_offset])
# Deallocate the gradient by setting it to None
a.grad = None
assert a.grad is None
assert b.grad is None
    # Re-set the gradient; setting it to None through the sliced tensor also clears it on the base tensor
a.grad = grad
assert a.grad is not None
assert b.grad is not None
b.grad = None
assert b.grad is None
assert a.grad is None
with assert_fail_with(NotImplementedError):
b.grad = torch.randn(1, 5)
with assert_fail_with(NotImplementedError):
del b.grad
c = b[:3]
    # It's important that slicing does not propagate the `SlicedFlatTensor` subclass to every derived tensor.
assert not isinstance(c, SlicedFlatTensor)
parallel_context.destroy()
"""
To process HuggingFace Datasets:
python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/emotion --n-tasks 16 hf --dataset dair-ai/emotion
To process Jsonl files:
python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --output-folder datasets/c4-es --n-tasks 16 jsonl --dataset raw_datasets/c4-es-json-files
"""
import argparse
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader
from datatrove.pipeline.tokens import DocumentTokenizer
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="Tokenizer")
group.add_argument(
"--tokenizer-name-or-path",
type=str,
required=True,
help="A path to a directory containing vocabulary files required by the tokenizer or the model id of a predefined tokenizer hosted inside a model repo on the Hugging Face Hub.",
)
group.add_argument(
"--eos-token",
type=str,
default=None,
help="EOS token to add after each document. Default: None",
)
group = parser.add_argument_group(title="Output data")
group.add_argument(
"--output-folder", type=str, required=True, help="Path to the output folder to store the tokenized documents"
)
group = parser.add_argument_group(title="Miscellaneous configs")
group.add_argument(
"--logging-dir",
type=str,
default=None,
help="Path to a folder for storing the logs of the preprocessing step. Default: None",
)
group.add_argument(
"--n-tasks", type=int, default=8, help="Total number of tasks to run the preprocessing step. Default: 8"
)
# Subparsers for processing either Hugging Face datasets or jsonl files
sp = parser.add_subparsers(
dest="readers",
required=True,
        description="Type of dataset to process. It can be either a Hugging Face Dataset loaded with datasets.load_dataset ('hf') or a .jsonl dataset ('jsonl')",
)
p1 = sp.add_parser(name="hf")
p1.add_argument(
"--dataset",
type=str,
required=True,
help="Path to local stored dataset or repository on the Hugging Face hub that can be loaded with datasets.load_dataset",
)
p1.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text")
p1.add_argument("--split", type=str, default="train", help="Which split of the data to process. Default: train")
p2 = sp.add_parser(name="jsonl")
p2.add_argument(
"--dataset",
type=str,
required=True,
help="Path to a .jsonl file or a folder containing multiple .jsonl files",
)
p2.add_argument("--column", type=str, default="text", help="Column to preprocess from the Dataset. Default: text")
p2.add_argument(
"--glob-pattern", type=str, default=None, help="A glob pattern to filter files to read. Default: None"
)
args = parser.parse_args()
return args
def main(args):
# Build datatrove reader
if args.readers == "hf":
datatrove_reader = HuggingFaceDatasetReader(
dataset=args.dataset,
text_key=args.column,
dataset_options={"split": args.split},
)
else:
datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
preprocess_executor = LocalPipelineExecutor(
pipeline=[
datatrove_reader,
DocumentTokenizer(
output_folder=args.output_folder,
tokenizer_name_or_path=args.tokenizer_name_or_path,
eos_token=args.eos_token,
shuffle=False,
max_tokens_per_file=1e9,
),
],
tasks=args.n_tasks,
logging_dir=args.logging_dir,
)
preprocess_executor.run()
if __name__ == "__main__":
_args = get_args()
main(_args)
# --nproc_per_node=8: dp=2, pp=2, tp=2
# --nproc_per_node=4: dp=1, pp=2, tp=2
# --nproc_per_node=1: dp=1, pp=1, tp=1
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/config_tiny_llama.yaml
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/config_tiny_llama_cosmo2tokenizer.yaml
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/config_llama3_dummytokenizer.yaml
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 run_train.py --config-file examples/config_llama3.yaml
# CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config-file smollm1/config_smollm1_135M_demo1.yaml
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config-file smollm1/config_smollm1_135M_demo2.yaml
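# Note: nproc_per_node must match dp * pp * tp of the chosen config (e.g. 8 = 2 * 2 * 2 in the combinations listed above).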