Unverified commit 4edcff57, authored by Shijie and committed by GitHub

[PyTorch] Support dtype casting in fused adam (#977)



* support dtype casting fusion in FusedAdam
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix lint
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* changes based on review comments
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* remove unused code
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* code refactor
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix typo
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* refactor
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* remove unused code
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Copy CUDA headers for framework sdists
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 30407856
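For reference, a minimal usage sketch of the new interface, adapted from the tests added in this commit (the module and hyperparameters are illustrative, not part of the change): FusedAdam now accepts a list of FP32 master weights, updates them in full precision, and casts the results back into the low-precision model parameters.

import torch
import transformer_engine.pytorch as te

# Module with low-precision parameters (BF16 here; FP8 weights created under
# te.fp8_model_init are exercised by the new tests below).
model = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
model_params = [p for p in model.parameters() if p.requires_grad]
master_params = [p.detach().clone().float() for p in model_params]

# New in this change: master_weights is a list of FP32 tensors instead of a bool.
optimizer = te.optimizers.FusedAdam(
    model_params,
    lr=5e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
    master_weights=master_params,
)
# After loss.backward(), optimizer.step() updates master_params in FP32 and
# writes the cast result back into model_params.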
@@ -14,7 +14,7 @@ import sys
import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
@functools.lru_cache(maxsize=None)
@@ -254,12 +254,39 @@ def get_frameworks() -> List[str]:
return _frameworks
def copy_common_headers(te_src, dst):
headers = te_src / "common"
for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :])
Path(new_path).parent.mkdir(exist_ok=True, parents=True)
shutil.copy(file_path, new_path)
def copy_common_headers(
src_dir: Union[Path, str],
dst_dir: Union[Path, str],
) -> None:
"""Copy headers from core library
src_dir should be the transformer_engine directory within the root
Transformer Engine repository. All .h and .cuh files within
transformer_engine/common are copied into dst_dir. Relative paths
are preserved.
"""
# Find common header files in src dir
headers = glob.glob(
os.path.join(str(src_dir), "common", "**", "*.h"),
recursive=True,
)
headers.extend(
glob.glob(
os.path.join(str(src_dir), "common", "**", "*.cuh"),
recursive=True,
)
)
headers = [Path(path) for path in headers]
# Copy common header files to dst dir
src_dir = Path(src_dir)
dst_dir = Path(dst_dir)
for path in headers:
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)
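A minimal call sketch for the helper above; the paths are illustrative and assume a Transformer Engine checkout, as described in the docstring:

from pathlib import Path

repo_root = Path("/path/to/TransformerEngine")  # assumed checkout location
copy_common_headers(
    src_dir=repo_root / "transformer_engine",  # per the docstring above
    dst_dir=repo_root / "build" / "include",   # hypothetical destination
)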
def install_and_import(package):
@@ -10,6 +10,13 @@ import torch
from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
class TestFusedOptimizer(unittest.TestCase):
@@ -169,6 +176,83 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param)
@unittest.skipIf(not is_bf16_compatible(), "bf16 is not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=dtype,
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
for i in range(self.iters):
self.gen_grad(ref_params, master_params)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True
)
@unittest.skipIf(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_weight_cast(self):
dtype = torch.bfloat16
with fp8_model_init(enabled=True):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=dtype,
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
for i in range(self.iters):
self.gen_grad(ref_params, master_params)
ref_optim.step()
tst_optim.step()
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True
)
class TestFusedSGD(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
@@ -345,8 +429,9 @@ class AdamTest(unittest.TestCase):
if m.__class__ in [torch.nn.Conv2d]:
m.half()
params_ = [p for p in self.model_.parameters() if p.requires_grad]
master_weights = [p.float() for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam(
params_, lr=self.lr, capturable=True, master_weights=True
params_, lr=self.lr, capturable=True, master_weights=master_weights
)
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True)
@@ -423,12 +423,19 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python);
using transformer_engine::DType;
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay);
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype);
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
@@ -8,16 +8,19 @@
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_fp8.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "common/utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
#define THREADS_PER_WARP 32
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
@@ -25,6 +28,156 @@ typedef enum {
} adamMode_t;
using MATH_T = float;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using transformer_engine::DType;
template <typename T>
struct is_fp8 : std::false_type {};
template <>
struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
template <bool is_fp8>
struct FP8Data {
float scale;
float *amax_ptr;
float *scale_inv_ptr;
float max;
int warp_id;
};
template <>
struct FP8Data<false> {};
template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctorMaster {
static constexpr bool is_fp8_type = is_fp8<PARAM_T>::value;
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<5, is_fp8_type> &tl, // NOLINT(*)
const float beta1, const float beta2,
const float beta1_correction,
const float beta2_correction, const float epsilon,
const float lr, adamMode_t mode, const float decay) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
FP8Data<is_fp8_type> fp8_data;
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];
GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;
PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;
FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;
FULL_T *p_master = reinterpret_cast<FULL_T *>(tl.addresses[4][tensor_loc]);
p_master += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
if constexpr (is_fp8_type) {
float *scale_ptr = reinterpret_cast<float *>(tl.fp8_meta_addresses[0][tensor_loc]);
fp8_data.scale = scale_ptr != nullptr ? *scale_ptr : 1;
fp8_data.amax_ptr = reinterpret_cast<float *>(tl.fp8_meta_addresses[1][tensor_loc]);
fp8_data.scale_inv_ptr = reinterpret_cast<float *>(tl.fp8_meta_addresses[2][tensor_loc]);
fp8_data.warp_id = threadIdx.x / THREADS_PER_WARP;
fp8_data.max = 0;
}
// see note in multi_tensor_scale_kernel.cu
for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = static_cast<MATH_T>(g[i]);
r_p[ii] = static_cast<MATH_T>(p_master[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p_master[i] = static_cast<FULL_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
if constexpr (is_fp8_type) {
__builtin_assume(fp8_data.max >= 0);
fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max);
p[i] = static_cast<PARAM_T>(r_p[ii] * fp8_data.scale);
} else {
p[i] = static_cast<PARAM_T>(r_p[ii]);
}
}
}
}
if constexpr (is_fp8_type) {
fp8_data.max = transformer_engine::reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(
fp8_data.max, fp8_data.warp_id);
if (threadIdx.x == 0) {
if (fp8_data.amax_ptr != nullptr) {
transformer_engine::atomicMaxFloat(fp8_data.amax_ptr, fp8_data.max);
}
if (fp8_data.scale_inv_ptr != nullptr) {
*fp8_data.scale_inv_ptr = __frcp_rn(fp8_data.scale);
}
}
}
}
};
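As a readability aid, a hedged pure-Python restatement of the per-element update performed by AdamFunctorMaster above (the helper name is illustrative). Mode 0 folds weight decay into the gradient (L2 regularization); any other mode applies decoupled, AdamW-style decay. For FP8 parameters the kernel additionally multiplies the updated master value by the per-tensor scale before the FP8 cast, tracks the running amax, and writes scale_inv = 1/scale at the end of the block.

def adam_master_update(p_master, g, m, v, *, lr, beta1, beta2, eps, step, decay, mode):
    # Bias corrections as in multi_tensor_adam_*_cuda, assuming bias_correction == 1;
    # otherwise both corrections are 1.
    bc1 = 1.0 - beta1 ** step
    bc2 = 1.0 - beta2 ** step
    if mode == 0:  # ADAM_MODE_0: L2 regularization folds decay into the gradient
        g = g + decay * p_master
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    update = (m / bc1) / ((v / bc2) ** 0.5 + eps)
    if mode != 0:  # decoupled weight decay (AdamW)
        update = update + decay * p_master
    p_master = p_master - lr * update
    return p_master, m, v  # the model copy is cast from p_master (scaled first if FP8)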
template <typename T, typename FULL_T, typename index_t>
struct AdamFunctor {
@@ -338,22 +491,114 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor_lists must contain 4 or 5 tensor lists");
if (requires_64bit_indexing) {
if (tl_size == 4) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
} else {
// g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
}
} else {
if (tl_size == 4) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
p_in_type, 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
} else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
}
}
AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) {
break;
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
auto tl_size = tensor_lists.size();
// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
TORCH_CHECK(tl_size == 8, "tensor_lists must contain 8 tensor lists");
if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
multi_tensor_apply<5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);));
}
AT_CUDA_CHECK(cudaGetLastError());
}
@@ -191,6 +191,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling",
@@ -12,38 +12,55 @@
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include "common/common.h"
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
template <int n, bool USE_FP8 = false>
struct TensorListMetadataBase {
void *addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
int block_to_chunk[depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
template <int n, bool USE_FP8 = false>
struct TensorListMetadata : public TensorListMetadataBase<n, USE_FP8> {};
template <int n>
struct TensorListMetadata<n, true> : public TensorListMetadataBase<n, true> {
void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]];
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl,
U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however it likes.
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists.size() == depth + 3,
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
}
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) { // No range-based for because I need indices
for (int l = 0; l < depth; l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
@@ -58,9 +75,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
}
}
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0,
"Size mismatch among tensor lists");
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
TensorListMetadata<depth, USE_FP8> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
@@ -72,12 +94,15 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++)
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr();
}
loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
@@ -87,7 +112,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
@@ -100,7 +124,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
tl.start_tensor_this_launch = t + 1;
} else {
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++) tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
for (int d = 0; d < depth; d++) {
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
}
if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++) {
tl.fp8_meta_addresses[i][0] = tl.fp8_meta_addresses[i][loc_tensor_info - 1];
}
}
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
@@ -8,6 +8,7 @@ from transformer_engine_torch import (
multi_tensor_l2norm,
multi_tensor_unscale_l2norm,
multi_tensor_adam,
multi_tensor_adam_fp8,
multi_tensor_adam_capturable,
multi_tensor_adam_capturable_master,
multi_tensor_sgd,
@@ -5,9 +5,27 @@
"""Fused Adam optimizer."""
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from .multi_tensor_apply import multi_tensor_applier
def get_fp8_meta(fp8_tensor):
"""FP8 metadata getter."""
if fp8_tensor._fp8_meta is None:
raise RuntimeError("FP8 meta data is not initialized.")
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=fp8_tensor._fp8_meta_forward,
)
fp8_meta_index = fp8_tensor._fp8_meta_index
scale = fp8_tensor._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = fp8_tensor._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
scale_inv = fp8_tensor._scale_inv
return scale, amax, scale_inv
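A hedged usage sketch for the helper above (the layer and sizes are illustrative, and it assumes the parameter's FP8 metadata was initialized at construction, as in the tests added in this commit): for a Float8Tensor parameter created under fp8_model_init, the helper returns the per-tensor scale, the current amax-history entry, and the scale_inv tensor that the FP8 Adam kernel updates in place.

import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init

with fp8_model_init(enabled=True):
    layer = te.Linear(1024, 1024)  # parameters are stored as Float8Tensor
scale, amax, scale_inv = get_fp8_meta(layer.weight)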
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
@@ -50,9 +68,11 @@ class FusedAdam(torch.optim.Optimizer):
method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False)
master_weights (bool, optional): whether to maintain FP32 master weights
in the optimizer with FP16 mixed precision training, currently can
only be used with capturable set to True. (default: False)
master_weights (list of torch.Tensor, optional): master weights to use
for mixed precision training. If provided, the optimizer will update
the master weights and then cast the master weights to the model weights.
If not provided, the optimizer will update the model weights directly.
(default: None)
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -72,15 +92,12 @@
amsgrad=False,
set_grad_none=True,
capturable=False,
master_weights=False,
master_weights=None,
):
if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
if master_weights and not capturable:
raise RuntimeError(
"Master weights is currently only supported with the capturable version."
)
# If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = dict(
@@ -95,20 +112,10 @@
self.set_grad_none = set_grad_none
self.capturable = capturable
self.master_weights = master_weights
# Create full precision master weights
self.param_groups_master = []
for _, pg in enumerate(self.param_groups):
param_list = pg["params"]
self.param_groups_master.append(
{
"params": [
p.clone().detach().float() if self.master_weights else None
for p in param_list
],
}
)
if master_weights is not None:
assert isinstance(master_weights, list), "master_weights must be a list if provided"
self.master_weights = master_weights
if capturable:
for idx, group in enumerate(self.param_groups):
@@ -123,6 +130,7 @@
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
self.multi_tensor_adam = tex.multi_tensor_adam
self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
@@ -147,7 +155,9 @@
if closure is not None:
loss = closure()
for group, group_master in zip(self.param_groups, self.param_groups_master):
master_param_idx = 0
for group in self.param_groups:
if len(group["params"]) == 0:
continue
device = group["params"][0].device
@@ -166,51 +176,131 @@
)
# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_bf, p_bf, m_bf, v_bf = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
p_16_master = []
p_32_master = []
for p, p_master in zip(group["params"], group_master["params"]):
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError("FusedAdam does not support sparse gradients.")
p_main_of_fp8_model = []
p_main_of_f16_model = []
g_of_fp8_model = []
g_of_f16_model = []
g_of_f32_model = []
m_of_fp8_model = []
m_of_f16_model = []
m_of_f32_model = []
v_of_fp8_model = []
v_of_f16_model = []
v_of_f32_model = []
p_fp8_model = []
p_f16_model = []
p_f32_model = []
# fp8 meta
scales = []
amaxes = []
scale_invs = []
# Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is.
out_dtype = tex.DType.kFloat32
has_fp16 = False
has_bf16 = False
for p in group["params"]:
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data).float()
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data).float()
# Master weights
if self.master_weights and p.dtype != torch.float32:
# model weights can be fp32/bf16/fp16/fp8
# If it's fp32, it has no corresponding master weights
state["master_param"] = self.master_weights[master_param_idx]
master_param_idx += 1
assert (
state["master_param"].shape == p.shape
), "Master weights shape must match model weights shape"
else:
state["master_param"] = None
p_master = state["master_param"]
p_grad = p.grad
if self.master_weights and p_master is not None and p_master.grad is not None:
p_grad = p_master.grad
if p_grad is None:
continue
if p_grad.data.is_sparse:
raise RuntimeError("FusedAdam does not support sparse gradients.")
if p.dtype == torch.float16:
if isinstance(p, Float8Tensor):
out_dtype = p._fp8_dtype
p_fp8_model.append(p._data.data)
scale, amax, scale_inv = get_fp8_meta(p)
scales.append(scale)
amaxes.append(amax)
scale_invs.append(scale_inv)
if self.master_weights:
p_16_master.append(p_master.data)
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state["exp_avg"])
v_16.append(state["exp_avg_sq"])
elif p.dtype == torch.bfloat16:
g_bf.append(p.grad)
p_bf.append(p)
m_bf.append(state["exp_avg"])
v_bf.append(state["exp_avg_sq"])
elif p.dtype == torch.float32:
p_main_of_fp8_model.append(p_master.data)
g_of_fp8_model.append(p_grad.data)
m_of_fp8_model.append(state["exp_avg"])
v_of_fp8_model.append(state["exp_avg_sq"])
elif p.dtype in [torch.float16, torch.bfloat16]:
has_fp16 = has_fp16 or p.dtype == torch.float16
has_bf16 = has_bf16 or p.dtype == torch.bfloat16
p_f16_model.append(p.data)
if self.master_weights:
p_32_master.append(p_master.data)
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state["exp_avg"])
v_32.append(state["exp_avg_sq"])
p_main_of_f16_model.append(p_master.data)
g_of_f16_model.append(p_grad.data)
m_of_f16_model.append(state["exp_avg"])
v_of_f16_model.append(state["exp_avg_sq"])
elif p.dtype == torch.float32:
p_f32_model.append(p.data)
g_of_f32_model.append(p_grad.data)
m_of_f32_model.append(state["exp_avg"])
v_of_f32_model.append(state["exp_avg_sq"])
else:
raise RuntimeError("FusedAdam only support fp16 and fp32.")
raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8")
if self.capturable and len(p_fp8_model) > 0:
raise RuntimeError(
"FusedAdam does not support FP8 model weights with capturable=True."
)
if has_fp16 and has_bf16:
# simple to add support for this, but not needed for now
raise RuntimeError(
"FusedAdam does not support a mix of float16 and bfloat16 model weights."
)
def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=None):
# Closures defined in a loop can have unexpected
# behavior when called outside the loop. However, this
# function is called in the same loop iteration as it
# is defined.
# pylint: disable=cell-var-from-loop
inv_scale_arg = () if inv_scale is None else (inv_scale,)
out_dtype_arg = () if out_dtype is None else (out_dtype,)
multi_tensor_applier(
adam_func,
self._dummy_overflow_buf,
tensor_lists,
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
*inv_scale_arg,
*out_dtype_arg,
)
if self.capturable:
# If the optimizer is capturable, the grad scaler (if any) works on the
# GPU and a different multi_tensor_applier must be called.
if self.capturable:
# overflow check of gradients
found_inf = (
grad_scaler._check_inf_per_device(self)[device]
@@ -228,113 +318,76 @@
scale = torch.ones((1,), device=device)
inv_scale = torch.ones((1,), device=device)
if len(g_16) > 0:
multi_tensor_applier(
(
self.multi_tensor_adam_capturable_master
if self.master_weights
else self.multi_tensor_adam_capturable
),
self._dummy_overflow_buf,
(
[g_16, p_16, m_16, v_16, p_16_master]
if self.master_weights
else [g_16, p_16, m_16, v_16]
),
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
)
if len(g_bf) > 0:
multi_tensor_applier(
self.multi_tensor_adam_capturable,
self._dummy_overflow_buf,
[g_bf, p_bf, m_bf, v_bf],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
if self.master_weights:
if len(p_f16_model) > 0:
tensor_lists = [
g_of_f16_model,
p_f16_model,
m_of_f16_model,
v_of_f16_model,
p_main_of_f16_model,
]
apply_multi_tensor_adam(
self.multi_tensor_adam_capturable_master, tensor_lists, inv_scale
)
if len(g_32) > 0:
multi_tensor_applier(
(
self.multi_tensor_adam_capturable_master
if self.master_weights
else self.multi_tensor_adam_capturable
),
self._dummy_overflow_buf,
(
[g_32, p_32, m_32, v_32, p_32_master]
if self.master_weights
else [g_32, p_32, m_32, v_32]
),
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
if len(p_f32_model) > 0:
tensor_lists = [
g_of_f32_model,
p_f32_model,
m_of_f32_model,
v_of_f32_model,
]
apply_multi_tensor_adam(
self.multi_tensor_adam_capturable, tensor_lists, inv_scale
)
else:
if len(g_16) > 0:
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_16, p_16, m_16, v_16],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
if len(p_f16_model) > 0:
tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model]
apply_multi_tensor_adam(
self.multi_tensor_adam_capturable, tensor_lists, inv_scale
)
if len(g_bf) > 0:
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_bf, p_bf, m_bf, v_bf],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
if len(p_f32_model) > 0:
tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model]
apply_multi_tensor_adam(
self.multi_tensor_adam_capturable, tensor_lists, inv_scale
)
if len(g_32) > 0:
multi_tensor_applier(
self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_32, p_32, m_32, v_32],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
)
elif self.master_weights: # and self.capturable=False
if len(p_f16_model) > 0:
tensor_lists = [
g_of_f16_model,
p_f16_model,
m_of_f16_model,
v_of_f16_model,
p_main_of_f16_model,
]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
if len(p_fp8_model) > 0:
tensor_lists = [
g_of_fp8_model,
p_fp8_model,
m_of_fp8_model,
v_of_fp8_model,
p_main_of_fp8_model,
scales,
amaxes,
scale_invs,
]
apply_multi_tensor_adam(self.multi_tensor_adam_fp8, tensor_lists, out_dtype)
if len(p_f32_model) > 0:
tensor_lists = [
g_of_f32_model,
p_f32_model,
m_of_f32_model,
v_of_f32_model,
]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
else: # self.master_weights=False and self.capturable=False
if len(p_f16_model) > 0:
tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
if len(p_f32_model) > 0:
tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
return loss