Unverified Commit e8a17d1e authored by Xin Yao, committed by GitHub

[PyTorch] Move FusedAdam/FusedSGD and necessary kernels from Apex to TE (#867)



* add multi-tensor kernels
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add FusedAdam
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add test to qa
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add FusedSGD
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent b1a0e0a7
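The moved optimizers are exposed under te.optimizers and are intended as drop-in replacements for their torch.optim counterparts. A minimal usage sketch (the model and hyperparameters here are placeholders; the constructor options mirror those exercised by the tests in this commit):

    import torch
    import transformer_engine.pytorch as te

    model = torch.nn.Linear(16, 16).cuda()
    params = [p for p in model.parameters() if p.requires_grad]
    # Same common options as torch.optim.Adam.
    optim = te.optimizers.FusedAdam(params, lr=5e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)
    # te.optimizers.FusedSGD(params, lr=0.25, momentum=0.125) is the matching SGD variant.

    out = model(torch.randn(8, 16, device="cuda"))
    out.square().mean().backward()
    optim.step()
    optim.zero_grad()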
@@ -20,3 +20,5 @@ pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
@@ -435,7 +435,11 @@ def setup_common_extension() -> CMakeExtension:
)
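# Collect every file under `path`, recursing into subdirectories via os.walk
# (unlike the previous non-recursive Path.iterdir()).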
def _all_files_in_dir(path):
all_files = []
for dirname, _, names in os.walk(path):
for name in names:
all_files.append(Path(dirname, name))
return all_files
def setup_pytorch_extension() -> setuptools.Extension:
"""Setup CUDA extension for PyTorch support"""
...
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from itertools import product
import unittest
import copy
import torch
from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
class TestFusedOptimizer(unittest.TestCase):
def setUp(self, iters=7):
self.iters = iters
torch.manual_seed(9876)
def tearDown(self):
pass
def gen_param_optim(self, tensors, options, tst_options=None):
# For backward compatibility with existing tests: if "tst_options" is not
# provided, fall back to a copy of "options", which holds the parameters
# for the reference optimizer.
if tst_options is None:
tst_options = options
ref_param = []
tst_param = []
for tensor in tensors:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = self.ref_optim(ref_param, **options)
tst_optim = self.fused_optim(tst_param, **tst_options)
return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param):
for p_ref, p_tst in zip(ref_param, tst_param):
p_ref.grad = torch.rand_like(p_ref)
p_tst.grad = p_ref.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = []
for p_ref, p_tst in zip(ref_param, tst_param):
half_grads.append(torch.rand_like(p_ref).half())
p_ref.grad = half_grads[-1].float() / scale
return half_grads
def gen_single_type_test(
self, param_type=torch.float, device="cuda", *, skip_assert: bool = False
):
nelem = 278011
# Some ref and test optimizers require different sets of options.
# As a quick workaround that keeps existing code unchanged: if no
# "tst_options" attribute is provided, it is safe to initialize the
# test optimizer with the reference optimizer's parameters.
if not hasattr(self, "tst_options"):
self.tst_options = self.options
tensor = torch.rand(nelem, dtype=param_type, device=device)
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], self.options, self.tst_options
)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
if skip_assert:
return
torch.testing.assert_close(ref_param, tst_param)
class TestFusedAdam(TestFusedOptimizer):
def setUp(self):
super().setUp()
self.options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
self.ref_optim = torch.optim.Adam
self.fused_optim = te.optimizers.FusedAdam
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
# NOTE(mkozuki): Current threshold values look too small for BFloat16.
# TODO(mkozuki): Refactor `TestFusedOptimizer`
def test_half(self):
self.gen_single_type_test(param_type=torch.float16, skip_assert=True)
def test_bfloat16(self):
self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
@unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
tensors, self.options
)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_param, tst_param)
def test_adam_option(self):
nelem = 1
adam_option = {
"lr": 0.01,
"betas": (0.6, 0.9),
"eps": 3e-06,
"weight_decay": 0,
"amsgrad": False,
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], adam_option
)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_param, tst_param)
def test_frozen_model(self):
nelem = 1
adam_option = {
"lr": 0.01,
"betas": (0.6, 0.9),
"eps": 3e-06,
"weight_decay": 0,
"amsgrad": False,
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
[tensor], adam_option
)
# Add an empty param group which may occur for pipeline parallel p-tuning
tst_optim.add_param_group({"params": []})
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_param, tst_param)
class TestFusedSGD(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
super(TestFusedSGD, self).__init__(*args, **kwargs)
self.options = {"lr": .25, "momentum": .125}
self.ref_optim = torch.optim.SGD
self.fused_optim = te.optimizers.FusedSGD
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 10)
self.relu5 = nn.ReLU()
def forward(self, x):
y = self.conv1(x)
y = self.relu1(y)
y = self.pool1(y)
y = self.conv2(y)
y = self.relu2(y)
y = self.pool2(y)
y = y.reshape(y.shape[0], -1)
y = self.fc1(y)
y = self.relu3(y)
y = self.fc2(y)
y = self.relu4(y)
y = self.fc3(y)
y = self.relu5(y)
return y
class AdamTest(unittest.TestCase):
def setUp(self, seed=0):
super().setUp()
torch.manual_seed(seed)
self.model = Model().cuda()
self.model_ = Model().cuda()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
self.lr = 0.00001
params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = torch.optim.Adam(params, lr=self.lr)
def testGradScaler(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
scaler_.scale(loss_).backward()
scaler_.step(optimizer_)
scaler_.update()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad,
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
def testGradScalerCapturable(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
scaler_.scale(loss_).backward()
scaler_.step(optimizer_)
scaler_.update()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad,
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
def testGradScalerCapturableMaster(self):
# Cast conv layers to FP16
for m in self.model_.modules():
if m.__class__ in [torch.nn.Conv2d]:
m.half()
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(
params_, lr=self.lr, capturable=True, master_weights=True
)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
with torch.cuda.amp.autocast(enabled=True):
y = self.model(x)
loss = ((gt - y) ** 2).mean()
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
# DUT
with torch.cuda.amp.autocast(enabled=True):
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
scaler_.scale(loss_).backward()
scaler_.step(optimizer_)
scaler_.update()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight,
m_.weight.float(),
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad.float(),
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
def testNative(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()
# Reference
y = self.model(x)
loss = ((gt - y) ** 2).mean()
loss.backward()
self.optimizer.step()
# DUT
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()
loss_.backward()
optimizer_.step()
for module in zip(self.model.modules(), self.model_.modules()):
m = module[0]
m_ = module[1]
if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
torch.testing.assert_close(
m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True
)
torch.testing.assert_close(
m.weight.grad,
m_.weight.grad,
atol=1e-3,
rtol=1e-3,
equal_nan=True,
)
# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
@largeTensorTest("60GB", "cuda")
def testLargeTensor(self):
t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
grad = torch.randn_like(t)
t.grad = grad
t2.grad = grad
params = [t]
params2 = [t2]
optimizer = te.optimizers.FusedAdam(params, lr=self.lr)
optimizer.step()
optimizer2 = torch.optim.Adam(params2, lr=self.lr)
optimizer2.step()
torch.testing.assert_close(t, t2)
torch.cuda.synchronize()
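For mixed precision, FusedAdam adds two flags beyond torch.optim.Adam: capturable=True (CUDA-graph-safe step) and master_weights=True (FP32 master copies for low-precision parameters). A minimal sketch of the combination exercised by testGradScalerCapturableMaster above (model and lr are placeholders):

    model = Model().cuda()
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            m.half()  # low-precision params; FusedAdam keeps FP32 master copies
    params = [p for p in model.parameters() if p.requires_grad]
    optim = te.optimizers.FusedAdam(params, lr=1e-5, capturable=True, master_weights=True)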
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
input_size_pairs = [
(7777 * 77, 555 * 555),
(777, 555),
(555, 2048 * 32 + 1),
(2048 * 32 + 1, 555),
(555, 2048 * 32),
(2048 * 32, 555),
(33333, 555),
(555, 33333),
]
appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply(33333)]
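# The applier chunk sizes are chosen to straddle the tensor sizes above:
# 2048 * 32 and 2048 * 32 +/- 1 elements hit the exact-chunk and
# one-element-past-a-chunk boundaries for MultiTensorApply(2048 * 32),
# while 333 forces many small chunks per tensor and 33333 covers most
# tensors in a single chunk.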
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("inplace", [False, True])
def test_multi_tensor_scale(
input_size_pair, applier, repeat, in_type, out_type, inplace
):
if inplace is True and (out_type is not in_type):
pytest.skip("inplace=True and out_type != in_type is not supported.")
elif (in_type == torch.float16 and out_type == torch.bfloat16) or (
in_type == torch.bfloat16 and out_type == torch.float16
):
pytest.skip("float16 to bfloat16 is not necessary and vice versa.")
device = torch.device("cuda")
scale = 4.0
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
ref = torch.tensor([1.0], dtype=torch.float32, device=device)
sizea, sizeb = input_size_pair
def downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=False):
overflow_buf.zero_()
a = torch.full([sizea], scale, dtype=torch.float32, device=device)
b = torch.full([sizeb], scale, dtype=torch.float32, device=device)
out_list = []
for i in range(repeat):
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
else:
in_list = [out.clone().to(in_type) for out in out_list]
applier(tex.multi_tensor_scale, overflow_buf, [in_list, out_list], 1.0 / scale)
assert all([torch.allclose(out, ref.to(out_type)) for out in out_list])
assert overflow_buf.item() == 0
def find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
t,
ind,
val,
inplace=False,
):
overflow_buf.zero_()
a = torch.full([sizea], scale, dtype=torch.float32, device=device)
b = torch.full([sizeb], scale, dtype=torch.float32, device=device)
out_list = []
for i in range(repeat):
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
else:
in_list = [out.clone().to(in_type) for out in out_list]
applier(tex.multi_tensor_scale, overflow_buf, [in_list, out_list], 1.0 / scale)
overflow_buf.zero_()
in_list[t][ind] = val
applier(tex.multi_tensor_scale, overflow_buf, [in_list, out_list], 1.0 / scale)
assert overflow_buf.item() > 0
downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
0,
0,
float("nan"),
inplace=inplace,
)
find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
2 * repeat - 1,
sizeb - 1,
float("inf"),
inplace=inplace,
)
find_inf(
sizea,
sizeb,
applier,
repeat,
in_type,
out_type,
2 * (repeat // 2),
sizea // 2,
float("inf"),
inplace=inplace,
)
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True])
def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tensor):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
val = 4.0
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
overflow_buf.zero_()
a = torch.full([sizea], val, dtype=torch.float32, device=device)
b = torch.full([sizeb], val, dtype=torch.float32, device=device)
in_list = []
for i in range(repeat):
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
norm, norm_per_tensor = applier(
tex.multi_tensor_l2norm, overflow_buf, [in_list], True
)
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
reference = torch.full(
[(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
).norm()
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
torch.testing.assert_close(
norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)
)
assert overflow_buf.item() == 0
@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True])
def test_multi_tensor_unscale_l2norm(
input_size_pair, applier, repeat, in_type, per_tensor
):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
val = 4.0
inv_scale = 0.5
inv_scale_cuda = torch.tensor([inv_scale], dtype=torch.float32, device=device)
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
overflow_buf.zero_()
a = torch.full([sizea], val, dtype=torch.float32, device=device)
b = torch.full([sizeb], val, dtype=torch.float32, device=device)
in_list = []
for i in range(repeat):
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
norm, norm_per_tensor = applier(
tex.multi_tensor_unscale_l2norm,
overflow_buf,
[in_list],
inv_scale_cuda,
True,
)
normab = torch.cat(
((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1))
)
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(
tex.multi_tensor_unscale_l2norm,
overflow_buf,
[in_list],
inv_scale_cuda,
True,
)
reference = torch.full(
[(sizea + sizeb) * repeat], val * inv_scale, dtype=torch.float32, device=device
).norm()
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
torch.testing.assert_close(
norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)
)
assert overflow_buf.item() == 0
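The call pattern for these bindings is compact; a minimal sketch distilled from the tests above (tensor names are hypothetical):

    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.optimizers import MultiTensorApply

    applier = MultiTensorApply(2048 * 32)
    noop = torch.zeros(1, dtype=torch.int32, device="cuda")  # set to 1 on inf/nan
    grads = [torch.randn(1024, device="cuda") for _ in range(4)]

    # Global L2 norm (and per-tensor norms) over the whole list in one pass.
    norm, per_tensor_norm = applier(tex.multi_tensor_l2norm, noop, [grads], True)

    # In-place downscale of the same tensors by 1/scale.
    applier(tex.multi_tensor_scale, noop, [grads, grads], 1.0 / 65536.0)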
@@ -21,6 +21,7 @@ from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
from .cpu_offload import get_cpu_offload_context
from . import optimizers
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
...
@@ -691,3 +691,44 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens,
int world_size,
int rank
);
/***************************************************************************************************
* multi_tensor_* kernels
**************************************************************************************************/
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python);
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay);
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale);
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale);
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode (AdamW)
} adamMode_t;
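// Both modes share the moment updates
//   m = beta1 * m + (1 - beta1) * g
//   v = beta2 * v + (1 - beta2) * g * g
// with bias corrections m / (1 - beta1^step) and v / (1 - beta2^step).
// ADAM_MODE_0 folds weight decay into the gradient before the moment updates
// (classic L2 regularization), while ADAM_MODE_1 applies decay directly in
// the parameter update (decoupled weight decay, as in AdamW).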
using MATH_T = float;
template <typename T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*)
const float beta1, const float beta2,
const float beta1_correction,
const float beta2_correction, const float epsilon,
const float lr, adamMode_t mode, const float decay) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially used to pass in a list of per-tensor scalars
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;
FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
template <typename T, typename FULL_T>
struct AdamCapturableFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*)
const float beta1, const float beta2, const int *step,
const int bias_correction, const float epsilon,
const float *lr, adamMode_t mode, const float decay,
const float *inv_scale) {
if (*noop_gmem == 1) return;
float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (bias_correction == 1) {
beta1_correction = 1 - pow(beta1, *step);
beta2_correction = 1 - pow(beta2, *step);
}
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially used to pass in a list of per-tensor scalars
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;
FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);
g[i] = static_cast<T>(r_g[ii]);
r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (*lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (*lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = static_cast<T>(r_p[ii]);
m[i] = static_cast<T>(r_m[ii]);
v[i] = static_cast<T>(r_v[ii]);
}
}
}
}
};
template <typename T, typename FULL_T>
struct AdamCapturableMasterFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<5> &tl, // NOLINT(*)
const float beta1, const float beta2, const int *step,
const int bias_correction, const float epsilon,
const float *lr, adamMode_t mode, const float decay,
const float *inv_scale) {
if (*noop_gmem == 1) return;
float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (bias_correction == 1) {
beta1_correction = 1 - pow(beta1, *step);
beta2_correction = 1 - pow(beta2, *step);
}
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially used to pass in a list of per-tensor scalars
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;
FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;
FULL_T *p_master = reinterpret_cast<FULL_T *>(tl.addresses[4][tensor_loc]);
p_master += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);
g[i] = static_cast<T>(r_g[ii]);
r_p[ii] = static_cast<MATH_T>(p_master[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (*lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (*lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = static_cast<T>(r_p[ii]);
p_master[i] = static_cast<FULL_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
}
};
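// Host entry point. Scans the tensor lists and switches to 64-bit indexing
// when any tensor has numel() >= INT_MAX, so that chunk offsets cannot
// overflow 32-bit ints.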
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) {
break;
}
}
if (requires_64bit_indexing) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
} else {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
}
AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
AT_CUDA_CHECK(cudaGetLastError());
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset) {
typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*)
}
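// is_aligned / load_store implement the vectorized fast path: the
// aligned_storage cast moves ILP consecutive elements (128 bits for ILP = 4
// floats) in a single load or store instead of ILP scalar accesses.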
template <typename x_t>
struct L2NormFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<1> &tl, // NOLINT(*),
float *output, float *output_per_tensor,
bool per_tensor, int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = reinterpret_cast<x_t *>(tl.addresses[0][tensor_loc]);
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]);
vals[ii] += next * next;
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]);
vals[ii] += next * next;
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] += final;
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
template <typename x_t>
struct UnscaleL2NormFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<1> &tl, // NOLINT(*),
const float *inv_scale, float *output,
float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = reinterpret_cast<x_t *>(tl.addresses[0][tensor_loc]);
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]) * (*inv_scale);
vals[ii] += next * next;
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]) * (*inv_scale);
vals[ii] += next * next;
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] += final;
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
// Probably better to template on the norm type, but since we are unlikely to
// support norms other than L2 and max, a separate functor keeps it simple.
template <typename x_t>
struct MaxNormFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<1> &tl, // NOLINT(*),
float *output, float *output_per_tensor,
bool per_tensor, int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = reinterpret_cast<x_t *>(tl.addresses[0][tensor_loc]);
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
float final = reduce_block_into_lanes_max_op(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
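// The functors above only emit partial results: per-block sums into output[]
// and per-chunk values into output_per_tensor. The cleanup kernels below
// finish the reduction: block 0 reduces the (up to 320) block partials into
// the final scalar norm, and with per_tensor set each block reduces one
// tensor's chunk partials into ret_per_tensor.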
__global__ void cleanup(float *output, float *output_per_tensor, float *ret, float *ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor) {
__shared__ float vals[512];
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(final);
}
if (per_tensor) {
float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
}
}
__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
float *ret_per_tensor, bool per_tensor, int max_chunks_per_tensor,
int norm_type, float alpha, float beta) {
__shared__ float vals[512];
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
if (norm_type == 0) {
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
} else {
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
}
}
if (per_tensor) {
float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;
if (norm_type == 0) {
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
} else {
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] =
sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
}
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(),
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset) {
typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*)
}
template <typename in_t, typename out_t>
struct ScaleFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<2> &tl, // NOLINT(*)
float scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
in_t *in = reinterpret_cast<in_t *>(tl.addresses[0][tensor_loc]);
in += chunk_idx * chunk_size;
out_t *out = reinterpret_cast<out_t *>(tl.addresses[1][tensor_loc]);
out += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_in, in, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
// store
load_store(out, r_out, i_start, 0);
}
} else {
// Non-divergent exit condition for __syncthreads, not necessary here
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite) *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
}
};
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
using namespace at;
// The output (downscaled) type is always float.
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
/**
* Perform fused SGD on multiple buffers
* N: number of tensors
* tl[0] : gradients
* tl[1] : weights
* tl[2] : momentum buffers
* tl[3] : fp16 weights (if appropriate)
* wd : weight_decay (scalar)
* momentum : momentum (scalar)
* dampening : momentum dampening (scalar)
* lr : learning rate (scalar)
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template <int N, typename T_grad, typename T_weight>
struct SGDFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem,
TensorListMetadata<N>& tl, // NOLINT(*)
float wd, float momentum,
float dampening, float lr, bool nesterov,
bool first_run, bool wd_after_momentum, float scale) {
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad* grad_in = reinterpret_cast<T_grad*>(tl.addresses[0][tensor_loc]);
grad_in += chunk_idx * chunk_size;
T_weight* weight_in = reinterpret_cast<T_weight*>(tl.addresses[1][tensor_loc]);
weight_in += chunk_idx * chunk_size;
T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]);
mom_in += chunk_idx * chunk_size;
at::Half* model_weights_out = nullptr;
if (N == 4) {
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx * chunk_size;
}
n -= chunk_idx * chunk_size;
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
// apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum) incoming_grads[ii] += wd * incoming_weights[ii];
if (momentum != 0.f) {
if (!first_run)
incoming_moms[ii] =
incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if (nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum) incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
}
}
}
}
};
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if (num_tensors == 4) {
for (size_t i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
}
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
  // We have 4 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half &&
num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
}
  // Disabled case: fp16 grads, fp32 weights, fp32 momentum, no fp16 copy
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && weight_type == at::ScalarType::Float && // NOLINT(*)
num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Float && // NOLINT(*)
num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && weight_type == at::ScalarType::Float && // NOLINT(*)
num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
} else {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
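For orientation, the pybind registrations below expose this routine to Python. A minimal sketch of a direct call into the all-fp32 path (Case 2), assuming the extension imports as transformer_engine_extensions as in the Python modules further down, and that all tensors are contiguous and on the same CUDA device; the argument order mirrors the C++ signature above:

import torch
import transformer_engine_extensions as tex

params = [torch.randn(1024, device="cuda") for _ in range(3)]
grads = [torch.rand_like(p) for p in params]
momenta = [torch.zeros_like(p) for p in params]
noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")

tex.multi_tensor_sgd(
    2048 * 32,                 # chunk_size
    noop_flag,                 # early-exit flag polled by the kernel
    [grads, params, momenta],  # three lists -> no fp16 model copy (Case 2)
    0.0,                       # wd
    0.9,                       # momentum
    0.0,                       # dampening
    0.1,                       # lr
    False,                     # nesterov
    True,                      # first_run: momenta initialized from grads
    False,                     # wd_after_momentum
    1.0,                       # scale
)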
...@@ -119,6 +119,25 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
        "Generate partitioned indices for inputs in THD format");
// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
"performed for L2 norm computation, and tensors are not updated)");
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling");
m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support, LR scheduling and FP32 master weights");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
  // Data structures
  py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
      .def(py::init<>())
...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (e.g. Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void *addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl,
U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory =
(contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++) tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
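The launch bookkeeping above (chunking each tensor, packing chunks into blocks, flushing whenever the metadata arrays fill) is easier to follow in scalar form. A rough Python model, illustrative only; the names are invented here, and the carry branch mirrors how the real code keeps a partially processed tensor in slot 0 for the next launch:

def plan_launches(numels, chunk_size, max_tensors, max_blocks):
    """Illustrative model of the chunk/block bookkeeping in multi_tensor_apply."""
    launches, tensors, blocks = [], [], []
    for t, n in enumerate(numels):
        tensors.append(t)
        chunks = (n + chunk_size - 1) // chunk_size
        for c in range(chunks):
            blocks.append((t, c))  # one CUDA block per (tensor, chunk) pair
            tensors_full = len(tensors) == max_tensors and c == chunks - 1
            blocks_full = len(blocks) == max_blocks
            last_chunk = t == len(numels) - 1 and c == chunks - 1
            if tensors_full or blocks_full or last_chunk:
                launches.append((list(tensors), list(blocks)))
                blocks.clear()
                if c == chunks - 1:
                    tensors = []   # tensor finished: start the next launch fresh
                else:
                    tensors = [t]  # carry the in-flight tensor over to slot 0
    return launches

print(plan_launches([100, 5000], chunk_size=1024, max_tensors=36, max_blocks=320))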
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Double: { \
using scalar_t_in = double; \
switch (TYPEOUT) { \
case at::ScalarType::Double: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
}
__syncthreads();
// Avoid potential write before read race when reduce_block_into_lanes is called back to back
return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
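As a sanity check on the reduction above, here is a NumPy model of what reduce_block_into_lanes computes for lanes == 1 and a power-of-two block size. It only mirrors the arithmetic; the CUDA version finishes the last 32 elements with warp shuffles rather than shared memory:

import numpy as np

def reduce_block(vals):
    # Pairwise tree reduction, as in the shared-memory halving loop above.
    x = np.asarray(vals, dtype=np.float64).copy()
    n = x.size
    while n > 1:
        half = n // 2
        x[:half] += x[half:n]
        n = half
    return x[0]

vals = np.random.rand(256)  # "block size" 256, a multiple of 32
assert np.isclose(reduce_block(vals), vals.sum())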
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused optimizers and multi-tensor kernels."""
from .fused_adam import FusedAdam
from .fused_sgd import FusedSGD
from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Adam optimizer."""
import torch
import transformer_engine_extensions as tex
from .multi_tensor_apply import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Currently GPU-only.
This version of fused Adam implements 2 fusions.
* Fusion of the Adam update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to
all the model's parameters into one or a few kernel launches.
:class:`te.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adam_w_mode=False``::
opt = te.optimizers.FusedAdam(model.parameters(), lr = ....)
...
opt.step()
:class:`te.optimizers.FusedAdam` may be used with or without Amp.
    Adam was proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
bias_correction (bool, optional): apply correction factor to
moment estimates. (default: True)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
        adam_w_mode (boolean, optional): whether to apply decoupled weight decay
            (also known as AdamW) instead of L2 regularization (default: True)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False)
master_weights (bool, optional): whether to maintain FP32 master weights
in the optimizer with FP16 mixed precision training, currently can
only be used with capturable set to True. (default: False)
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
adam_w_mode=True,
weight_decay=0.0,
amsgrad=False,
set_grad_none=True,
capturable=False,
master_weights=False,
):
if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
if master_weights and not capturable:
raise RuntimeError(
"Master weights is currently only supported with the capturable version."
)
# If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
super().__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
self.capturable = capturable
self.master_weights = master_weights
# Create full precision master weights
self.param_groups_master = []
for _, pg in enumerate(self.param_groups):
param_list = pg["params"]
self.param_groups_master.append(
{
"params": [
p.clone().detach().float() if self.master_weights else None
for p in param_list
],
}
)
if capturable:
for idx, group in enumerate(self.param_groups):
if len(group["params"]) == 0:
continue
device = group["params"][0].device
for item in ["lr"]:
self.param_groups[idx][item] = group[item].to(device=device)
self._step_supports_amp_scaling = True
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
self.multi_tensor_adam = tex.multi_tensor_adam
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = (
tex.multi_tensor_adam_capturable_master
)
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group["params"]:
p.grad = None
else:
super().zero_grad()
def step(self, closure=None, grad_scaler=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grad_scaler (torch.cuda.amp.GradScaler, optional):
gradient scaler (default: None)
"""
loss = None
if closure is not None:
loss = closure()
for group, group_master in zip(self.param_groups, self.param_groups_master):
if len(group["params"]) == 0:
continue
device = group["params"][0].device
bias_correction = 1 if group["bias_correction"] else 0
beta1, beta2 = group["betas"]
            # Assume the same step across the group for now to simplify things;
            # a per-parameter step could be supported by making it a tensor or
            # by passing a list into the kernel.
if "step" in group:
group["step"] += (
1
if not self.capturable
else (self._dummy_overflow_buf != 1).to(torch.int)
)
else:
group["step"] = (
1
if not self.capturable
else torch.tensor([1], dtype=torch.int, device=device)
)
# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_bf, p_bf, m_bf, v_bf = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
p_16_master = []
p_32_master = []
for p, p_master in zip(group["params"], group_master["params"]):
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError(
"FusedAdam does not support sparse gradients."
)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data).float()
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data).float()
if p.dtype == torch.float16:
if self.master_weights:
p_16_master.append(p_master.data)
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state["exp_avg"])
v_16.append(state["exp_avg_sq"])
elif p.dtype == torch.bfloat16:
g_bf.append(p.grad)
p_bf.append(p)
m_bf.append(state["exp_avg"])
v_bf.append(state["exp_avg_sq"])
elif p.dtype == torch.float32:
if self.master_weights:
p_32_master.append(p_master.data)
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state["exp_avg"])
v_32.append(state["exp_avg_sq"])
else:
raise RuntimeError("FusedAdam only support fp16 and fp32.")
            # If the optimizer is capturable, the grad scaler (if any) must do its
            # work on the GPU, and the capturable multi-tensor kernels must be used.
if self.capturable:
# overflow check of gradients
found_inf = (
grad_scaler._check_inf_per_device(self)[device]
if grad_scaler is not None
else torch.zeros((1,), device=device)
)
self._dummy_overflow_buf.copy_(found_inf)
# get unscale scale factor
scale, inv_scale = None, None
if grad_scaler:
scale = grad_scaler._get_scale_async()
inv_scale = scale.double().reciprocal().float()
else:
scale = torch.ones((1,), device=device)
inv_scale = torch.ones((1,), device=device)
if len(g_16) > 0:
multi_tensor_applier(
(
self.multi_tensor_adam_capturable_master
if self.master_weights
else self.multi_tensor_adam_capturable
),
self._dummy_overflow_buf,
(
[g_16, p_16, m_16, v_16, p_16_master]
if self.master_weights
else [g_16, p_16, m_16, v_16]
),
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
)
if len(g_bf) > 0:
multi_tensor_applier(
self.multi_tensor_adam_capturable,
self._dummy_overflow_buf,
[g_bf, p_bf, m_bf, v_bf],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
)
if len(g_32) > 0:
multi_tensor_applier(
(
self.multi_tensor_adam_capturable_master
if self.master_weights
else self.multi_tensor_adam_capturable
),
self._dummy_overflow_buf,
(
[g_32, p_32, m_32, v_32, p_32_master]
if self.master_weights
else [g_32, p_32, m_32, v_32]
),
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
)
else:
if len(g_16) > 0:
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_16, p_16, m_16, v_16],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
)
if len(g_bf) > 0:
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_bf, p_bf, m_bf, v_bf],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
)
if len(g_32) > 0:
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_32, p_32, m_32, v_32],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
)
return loss
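A short usage sketch for the class above, assuming the te.optimizers import path from the docstring and a CUDA device; this mirrors the drop-in AdamW replacement described there rather than adding new behavior:

import torch
import transformer_engine.pytorch as te

model = torch.nn.Linear(64, 64).cuda()
opt = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

for _ in range(3):
    loss = model(torch.randn(8, 64, device="cuda")).square().mean()
    opt.zero_grad()   # sets grads to None by default (set_grad_none=True)
    loss.backward()
    opt.step()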
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused SGD optimizer."""
import torch
from torch.optim.optimizer import Optimizer, required
import transformer_engine_extensions as tex
from .multi_tensor_apply import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Currently GPU-only.
This version of fused SGD implements 2 fusions.
* Fusion of the SGD update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to
all the model's parameters into one or a few kernel launches.
:class:`te.optimizers.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``::
opt = te.optimizers.FusedSGD(model.parameters(), lr = ....)
...
opt.step()
:class:`te.optimizers.FusedSGD` may be used with or without Amp.
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et. al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et. al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""
def __init__(
self,
params,
lr=required,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True,
set_grad_none=False,
):
if lr is not required and lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
self.set_grad_none = set_grad_none
# Skip buffer
self._dummy_overflow_buf = torch.tensor(
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device
)
self.multi_tensor_sgd = tex.multi_tensor_sgd
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group["params"]:
p.grad = None
else:
super().zero_grad()
def get_momentums(self, params):
"""Get momentum buffers of parameters. Create if needed.
Arguments:
params (List): List of parameters.
"""
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
            # torch.optim.SGD initializes momentum in its main loop; we have to
            # do it here and track whether we've done so, so that the kernel can
            # take its first-run path and initialize momentum from the gradients.
if "momentum_buffer" not in param_state:
first_run = True
buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state["momentum_buffer"])
return momentums, first_run
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
explicit_master_params = hasattr(self, "_amp_stash") and hasattr(
self._amp_stash, "fp32_from_fp16_groups"
)
for gid, group in enumerate(self.param_groups):
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
nesterov = group["nesterov"]
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
first_runs = [True, True]
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if explicit_master_params:
stash = self._amp_stash
fp32_params = [
p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None
]
fp32_grads = [
p.grad
for p in stash.fp32_from_fp32_groups[gid]
if p.grad is not None
]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if self.materialize_master_grads:
fp16_model_params = [
p
for i, p in enumerate(stash.fp16_groups[gid])
if stash.fp32_from_fp16_groups[gid][i].grad is not None
]
fp32_from_fp16_grads = [
p.grad
for p in stash.fp32_from_fp16_groups[gid]
if p.grad is not None
]
fp32_from_fp16_params = [
p
for p in stash.fp32_from_fp16_groups[gid]
if p.grad is not None
]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(
fp32_from_fp16_params
)
fp16_set = [
fp32_from_fp16_grads,
fp32_from_fp16_params,
fp32_from_fp16_momentums,
fp16_model_params,
]
else:
fp16_model_params = [
p for p in stash.fp16_groups[gid] if p.grad is not None
]
fp16_model_grads = [
p.grad for p in stash.fp16_groups[gid] if p.grad is not None
]
fp32_from_fp16_params = [
p
for i, p in enumerate(stash.fp32_from_fp16_groups[gid])
if stash.fp16_groups[gid][i].grad is not None
]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(
fp32_from_fp16_params
)
fp16_set = [
fp16_model_grads,
fp32_from_fp16_params,
fp32_from_fp16_momentums,
fp16_model_params,
]
launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [
p
for p in group["params"]
if (p.dtype == torch.float16 and p.grad is not None)
]
fp16_grads = [
p.grad
for p in group["params"]
if (p.dtype == torch.float16 and p.grad is not None)
]
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
fp32_params = [
p
for p in group["params"]
if (p.dtype == torch.float32 and p.grad is not None)
]
fp32_grads = [
p.grad
for p in group["params"]
if (p.dtype == torch.float32 and p.grad is not None)
]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets = [
[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums],
]
for _, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
launch_set,
weight_decay,
momentum,
dampening,
group["lr"],
nesterov,
first_run,
self.wd_after_momentum,
1.0 / self.most_recent_scale,
)
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
return loss
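A matching sketch for the class above, under the same import-path assumption; the first step() call takes the first_run path in the kernel and initializes the momentum buffers:

import torch
import transformer_engine.pytorch as te

model = torch.nn.Linear(32, 32).cuda()
opt = te.optimizers.FusedSGD(model.parameters(), lr=0.1, momentum=0.9)

loss = model(torch.randn(4, 32, device="cuda")).sum()
loss.backward()
opt.step()  # momentum buffers created and initialized on this first call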
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-tensor apply entry."""
class MultiTensorApply: # pylint: disable=too-few-public-methods
"""Multi-tensor apply entry."""
def __init__(self, chunk_size):
self.chunk_size = chunk_size
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
multi_tensor_applier = MultiTensorApply(2048 * 32)
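Since the applier just forwards its fixed chunk size plus the caller's arguments to the bound op, any of the kernels registered in the pybind module can go through it. A hedged example with multi_tensor_scale, assuming the package path implied by the __init__ above; the [src, dst] list layout is an assumption based on the Apex-style scale kernel:

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.optimizers import multi_tensor_applier

src = [torch.randn(1000, device="cuda") for _ in range(4)]
dst = [torch.empty_like(t) for t in src]
noop = torch.zeros(1, dtype=torch.int, device="cuda")

# One fused launch scaling every tensor in the list by 0.5; the noop flag
# is set by the kernel if a non-finite value is encountered.
multi_tensor_applier(tex.multi_tensor_scale, noop, [src, dst], 0.5)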