Unverified Commit 4edcff57 authored by Shijie's avatar Shijie Committed by GitHub
Browse files

[PyTorch] Support dtype casting in fused adam (#977)



* support dtype casting fusion in FusedAdam
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* minor changes
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* fix lint
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* changes based on review comments
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* remove unused code
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* code refactor
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* fix typo
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* refactor
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* remove unused code
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Copy CUDA headers for framework sdists
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
parent 30407856
...@@ -14,7 +14,7 @@ import sys ...@@ -14,7 +14,7 @@ import sys
import importlib import importlib
from pathlib import Path from pathlib import Path
from subprocess import CalledProcessError from subprocess import CalledProcessError
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
...@@ -254,12 +254,39 @@ def get_frameworks() -> List[str]: ...@@ -254,12 +254,39 @@ def get_frameworks() -> List[str]:
return _frameworks return _frameworks
def copy_common_headers(te_src, dst): def copy_common_headers(
headers = te_src / "common" src_dir: Union[Path, str],
for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True): dst_dir: Union[Path, str],
new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :]) ) -> None:
Path(new_path).parent.mkdir(exist_ok=True, parents=True) """Copy headers from core library
shutil.copy(file_path, new_path)
src_dir should be the transformer_engine directory within the root
Transformer Engine repository. All .h and .cuh files within
transformer_engine/common are copied into dst_dir. Relative paths
are preserved.
"""
# Find common header files in src dir
headers = glob.glob(
os.path.join(str(src_dir), "common", "**", "*.h"),
recursive=True,
)
headers.extend(
glob.glob(
os.path.join(str(src_dir), "common", "**", "*.cuh"),
recursive=True,
)
)
headers = [Path(path) for path in headers]
# Copy common header files to dst dir
src_dir = Path(src_dir)
dst_dir = Path(dst_dir)
for path in headers:
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)
def install_and_import(package): def install_and_import(package):
......
...@@ -10,6 +10,13 @@ import torch ...@@ -10,6 +10,13 @@ import torch
from torch import nn from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
class TestFusedOptimizer(unittest.TestCase): class TestFusedOptimizer(unittest.TestCase):
...@@ -169,6 +176,83 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -169,6 +176,83 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param) torch.testing.assert_close(ref_param, tst_param)
@unittest.skipIf(not is_bf16_compatible(), "bf16 is not supported")
def test_bf16_model_weight_cast(self):
    """FusedAdam with explicit FP32 master weights on a BF16 model.

    The optimizer must keep the master weights in lockstep with a
    reference FP32 ``torch.optim.Adam`` and cast the updated master
    weights back into the BF16 model weights within BF16 tolerance.
    """
    dtype = torch.bfloat16
    model = MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        layer_number=1,
        params_dtype=dtype,
        fuse_qkv_params=True,
    ).cuda()
    ref_params = []
    master_params = []
    model_params = []
    for p in model.parameters():
        if p.requires_grad:
            # FP32 copies: one set for the reference optimizer, one set of
            # master weights for FusedAdam; the BF16 params stay in the model.
            ref_params.append(p.detach().clone().float())
            master_params.append(p.detach().clone().float())
            model_params.append(p)
    options = {
        "lr": 5e-4,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    }
    ref_optim = torch.optim.Adam(ref_params, **options)
    tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)

    for _ in range(self.iters):
        self.gen_grad(ref_params, master_params)
        ref_optim.step()
        tst_optim.step()
        # Master weights must match the FP32 reference exactly (default tol).
        torch.testing.assert_close(ref_params, master_params)
        # Model weights are a BF16 cast of the master weights; compare in
        # FP32 with a tolerance loose enough for BF16 rounding.
        model_params_to_fp32 = [p.float() for p in model_params]
        torch.testing.assert_close(
            ref_params, model_params_to_fp32, rtol=1e-3, atol=1e-3, equal_nan=True
        )
@unittest.skipIf(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_weight_cast(self):
    """FusedAdam with FP32 master weights on an FP8-initialized model.

    Master weights must track a reference FP32 ``torch.optim.Adam``; the
    FP8 model weights must stay within FP8-level tolerance of the
    reference after being cast back from the master weights.
    """
    dtype = torch.bfloat16
    with fp8_model_init(enabled=True):
        model = MultiheadAttention(
            hidden_size=1024,
            num_attention_heads=16,
            layer_number=1,
            params_dtype=dtype,
            fuse_qkv_params=True,
        ).cuda()

    # FP32 copies for the reference optimizer and for FusedAdam's master
    # weights; the FP8 params themselves remain in the model.
    model_params = [p for p in model.parameters() if p.requires_grad]
    ref_params = [p.detach().clone().float() for p in model_params]
    master_params = [p.detach().clone().float() for p in model_params]

    adam_kwargs = {
        "lr": 5e-4,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    }
    ref_optim = torch.optim.Adam(ref_params, **adam_kwargs)
    tst_optim = te.optimizers.FusedAdam(
        model_params, master_weights=master_params, **adam_kwargs
    )

    for _ in range(self.iters):
        self.gen_grad(ref_params, master_params)
        ref_optim.step()
        tst_optim.step()
        # Master weights must match the FP32 reference exactly (default tol).
        torch.testing.assert_close(ref_params, master_params)
        # Model weights were quantized to FP8, so compare with a coarse
        # tolerance after upcasting to FP32.
        upcast_model_params = [p.float() for p in model_params]
        torch.testing.assert_close(
            ref_params, upcast_model_params, rtol=1e-2, atol=1e-2, equal_nan=True
        )
class TestFusedSGD(TestFusedOptimizer): class TestFusedSGD(TestFusedOptimizer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -345,8 +429,9 @@ class AdamTest(unittest.TestCase): ...@@ -345,8 +429,9 @@ class AdamTest(unittest.TestCase):
if m.__class__ in [torch.nn.Conv2d]: if m.__class__ in [torch.nn.Conv2d]:
m.half() m.half()
params_ = [p for p in self.model_.parameters() if p.requires_grad] params_ = [p for p in self.model_.parameters() if p.requires_grad]
master_weights = [p.float() for p in self.model_.parameters() if p.requires_grad]
optimizer_ = te.optimizers.FusedAdam( optimizer_ = te.optimizers.FusedAdam(
params_, lr=self.lr, capturable=True, master_weights=True params_, lr=self.lr, capturable=True, master_weights=master_weights
) )
scaler = torch.cuda.amp.GradScaler(enabled=True) scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_ = torch.cuda.amp.GradScaler(enabled=True) scaler_ = torch.cuda.amp.GradScaler(enabled=True)
......
...@@ -423,12 +423,19 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda( ...@@ -423,12 +423,19 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python); at::Tensor inv_scale, at::optional<bool> per_tensor_python);
using transformer_engine::DType;
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr, std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay); const float weight_decay);
// Fused Adam update for FP8 model weights with FP32 master weights.
// tensor_lists layout: {g, p_fp8, m, v, p_master, scale, amax, scale_inv}
// (see the definition's "case 8" check); fp8_dtype selects the FP8
// encoding used when casting updated master weights back to p_fp8.
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, DType fp8_dtype);
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2, at::Tensor lr, const float beta1, const float beta2,
......
...@@ -8,16 +8,19 @@ ...@@ -8,16 +8,19 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <cuda_fp8.h>
// Another possibility: // Another possibility:
// #include <torch/all.h> // #include <torch/all.h>
#include <assert.h> #include <assert.h>
#include "common/utils.cuh"
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#include "type_shim.h" #include "type_shim.h"
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
#define THREADS_PER_WARP 32
typedef enum { typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode ADAM_MODE_0 = 0, // L2 regularization mode
...@@ -25,6 +28,156 @@ typedef enum { ...@@ -25,6 +28,156 @@ typedef enum {
} adamMode_t; } adamMode_t;
using MATH_T = float; using MATH_T = float;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using transformer_engine::DType;
// Compile-time trait: true only for the two FP8 storage types
// (fp8e4m3 / fp8e5m2); everything else is treated as non-FP8.
template <typename T>
struct is_fp8 : std::false_type {};
template <>
struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
// Per-tensor FP8 bookkeeping used inside the Adam functor: the scale
// applied when casting to FP8, pointers for writing back amax/scale_inv,
// the running per-thread abs-max, and the thread's warp id for the
// block-level max reduction.
template <bool is_fp8>
struct FP8Data {
  float scale;           // cast scale read from the tensor's FP8 meta
  float *amax_ptr;       // destination for the reduced abs-max (may be null)
  float *scale_inv_ptr;  // destination for 1/scale (may be null)
  float max;             // per-thread running abs-max of updated weights
  int warp_id;           // threadIdx.x / THREADS_PER_WARP, for reduce_max
};

// Empty specialization: non-FP8 instantiations carry no extra state.
template <>
struct FP8Data<false> {};
// Adam update that maintains FP32 master weights separately from the
// (possibly lower-precision, including FP8) model weights.
// Tensor list layout (depth 5): [0]=grads (GRAD_T), [1]=model params
// (PARAM_T), [2]=exp_avg m, [3]=exp_avg_sq v, [4]=FP32 master params
// (FULL_T). The update is computed in MATH_T (float) from the master
// weights, written back to m/v/p_master, and cast to PARAM_T for the
// model weights. For FP8 PARAM_T the cast is scaled and amax/scale_inv
// are updated from the per-tensor FP8 meta pointers in `tl`.
template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctorMaster {
  // Compile-time switch: enables the FP8 scaling/amax bookkeeping paths.
  static constexpr bool is_fp8_type = is_fp8<PARAM_T>::value;

  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
                                             TensorListMetadata<5, is_fp8_type> &tl,  // NOLINT(*)
                                             const float beta1, const float beta2,
                                             const float beta1_correction,
                                             const float beta2_correction, const float epsilon,
                                             const float lr, adamMode_t mode, const float decay) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    FP8Data<is_fp8_type> fp8_data;

    // Locate this block's tensor and chunk within the flattened launch.
    index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

    index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
    index_t n = tl.sizes[tensor_loc];

    // Advance each tensor's base pointer to the start of this chunk.
    GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
    g += chunk_idx * chunk_size;

    PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
    p += chunk_idx * chunk_size;

    FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
    m += chunk_idx * chunk_size;

    FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
    v += chunk_idx * chunk_size;

    FULL_T *p_master = reinterpret_cast<FULL_T *>(tl.addresses[4][tensor_loc]);
    p_master += chunk_idx * chunk_size;

    // Remaining elements in this tensor from the chunk's start.
    n -= chunk_idx * chunk_size;

    if constexpr (is_fp8_type) {
      // FP8 meta pointers: [0]=scale, [1]=amax, [2]=scale_inv. A missing
      // scale pointer falls back to 1 (identity scaling).
      float *scale_ptr = reinterpret_cast<float *>(tl.fp8_meta_addresses[0][tensor_loc]);
      fp8_data.scale = scale_ptr != nullptr ? *scale_ptr : 1;
      fp8_data.amax_ptr = reinterpret_cast<float *>(tl.fp8_meta_addresses[1][tensor_loc]);
      fp8_data.scale_inv_ptr = reinterpret_cast<float *>(tl.fp8_meta_addresses[2][tensor_loc]);
      fp8_data.warp_id = threadIdx.x / THREADS_PER_WARP;
      fp8_data.max = 0;
    }

    // see note in multi_tensor_scale_kernel.cu
    for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];
      // Load ILP elements per thread into registers (zeros past the end).
      // Note: parameters are read from the FP32 master copy, not from p.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          r_g[ii] = static_cast<MATH_T>(g[i]);
          r_p[ii] = static_cast<MATH_T>(p_master[i]);
          r_m[ii] = static_cast<MATH_T>(m[i]);
          r_v[ii] = static_cast<MATH_T>(v[i]);
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }
      // Adam step in registers: L2-regularization mode folds decay into
      // the gradient; weight-decay mode adds decay to the update (AdamW).
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if (mode == ADAM_MODE_0) {  // L2
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = next_m_unbiased / denom;
          r_p[ii] = r_p[ii] - (lr * update);
        } else {  // weight decay
          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (lr * update);
        }
      }
      // Store: full-precision results to master/m/v; model weights get a
      // cast (scaled for FP8, with the abs-max tracked for amax).
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          p_master[i] = static_cast<FULL_T>(r_p[ii]);
          m[i] = static_cast<FULL_T>(r_m[ii]);
          v[i] = static_cast<FULL_T>(r_v[ii]);
          if constexpr (is_fp8_type) {
            __builtin_assume(fp8_data.max >= 0);
            fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max);
            p[i] = static_cast<PARAM_T>(r_p[ii] * fp8_data.scale);
          } else {
            p[i] = static_cast<PARAM_T>(r_p[ii]);
          }
        }
      }
    }

    if constexpr (is_fp8_type) {
      // Reduce the per-thread abs-max across the block, then have thread 0
      // publish amax (atomically, since multiple blocks may share a tensor)
      // and the reciprocal of the scale.
      fp8_data.max = transformer_engine::reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(
          fp8_data.max, fp8_data.warp_id);
      if (threadIdx.x == 0) {
        if (fp8_data.amax_ptr != nullptr) {
          transformer_engine::atomicMaxFloat(fp8_data.amax_ptr, fp8_data.max);
        }
        if (fp8_data.scale_inv_ptr != nullptr) {
          *fp8_data.scale_inv_ptr = __frcp_rn(fp8_data.scale);
        }
      }
    }
  }
};
template <typename T, typename FULL_T, typename index_t> template <typename T, typename FULL_T, typename index_t>
struct AdamFunctor { struct AdamFunctor {
...@@ -338,22 +491,114 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -338,22 +491,114 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
} }
} }
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5");
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
if (tl_size == 4) {
// Assume single type across p,g,m1,m2 now // Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam", p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2, AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);) weight_decay);)
} else { } else {
// g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
}
} else {
if (tl_size == 4) {
// Assume single type across p,g,m1,m2 now // Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam", p_in_type, 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2, AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);) weight_decay);)
} else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
}
}
AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) {
break;
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
auto tl_size = tensor_lists.size();
// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors");
if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
multi_tensor_apply<5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);));
} }
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
......
...@@ -191,6 +191,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -191,6 +191,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda, m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer", "Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
"Compute and apply gradient update to parameters for Adam optimizer with CUDA graph " "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
"support and LR scheduling", "support and LR scheduling",
......
...@@ -12,38 +12,55 @@ ...@@ -12,38 +12,55 @@
#include <assert.h> #include <assert.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "common/common.h"
// This header is the one-stop shop for all your multi-tensor apply needs. // This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
template <int n> template <int n, bool USE_FP8 = false>
struct TensorListMetadata { struct TensorListMetadataBase {
void *addresses[n][depth_to_max_tensors[n - 1]]; void *addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]]; int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. int block_to_chunk[depth_to_max_blocks[n - 1]];
int start_tensor_this_launch; int start_tensor_this_launch;
}; };
// Primary template: without FP8, the metadata is exactly the base struct.
template <int n, bool USE_FP8 = false>
struct TensorListMetadata : public TensorListMetadataBase<n, USE_FP8> {};

// FP8 specialization: adds per-tensor pointers to the three FP8 meta
// tensors (scale, amax, scale_inv), indexed like `addresses`.
template <int n>
struct TensorListMetadata<n, true> : public TensorListMetadataBase<n, true> {
  void *fp8_meta_addresses[3][depth_to_max_tensors[n - 1]];
};
template <typename T, typename U, typename... ArgTypes> template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl, __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop_flag, T tl,
U callable, ArgTypes... args) { U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however it likes. // Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(chunk_size, noop_flag, tl, args...); callable(chunk_size, noop_flag, tl, args...);
} }
template <int depth, typename T, typename... ArgTypes> template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag, void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable, const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) { ArgTypes... args) {
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists.size() == depth + 3,
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
}
int len0 = tensor_lists[0].size(); int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device(); auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) { // No range-based for because I need indices for (int l = 0; l < depth; l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) { for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails. // TODO: Print which tensor fails.
...@@ -58,9 +75,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -58,9 +75,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
} }
} }
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0,
"Size mismatch among tensor lists");
}
int ntensors = tensor_lists[0].size(); int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl; TensorListMetadata<depth, USE_FP8> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
...@@ -72,12 +94,15 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -72,12 +94,15 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++)
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr();
}
loc_tensor_info++; loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk; tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++; loc_block_info++;
...@@ -87,7 +112,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -87,7 +112,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) { if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>( multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...); chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
...@@ -100,7 +124,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -100,7 +124,14 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
tl.start_tensor_this_launch = t + 1; tl.start_tensor_this_launch = t + 1;
} else { } else {
tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++) tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; for (int d = 0; d < depth; d++) {
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
}
if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++) {
tl.fp8_meta_addresses[i][0] = tl.fp8_meta_addresses[i][loc_tensor_info - 1];
}
}
loc_tensor_info = 1; loc_tensor_info = 1;
tl.start_tensor_this_launch = t; tl.start_tensor_this_launch = t;
} }
......
...@@ -8,6 +8,7 @@ from transformer_engine_torch import ( ...@@ -8,6 +8,7 @@ from transformer_engine_torch import (
multi_tensor_l2norm, multi_tensor_l2norm,
multi_tensor_unscale_l2norm, multi_tensor_unscale_l2norm,
multi_tensor_adam, multi_tensor_adam,
multi_tensor_adam_fp8,
multi_tensor_adam_capturable, multi_tensor_adam_capturable,
multi_tensor_adam_capturable_master, multi_tensor_adam_capturable_master,
multi_tensor_sgd, multi_tensor_sgd,
......
...@@ -5,9 +5,27 @@ ...@@ -5,9 +5,27 @@
"""Fused Adam optimizer.""" """Fused Adam optimizer."""
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
def get_fp8_meta(fp8_tensor):
    """Return the (scale, amax, scale_inv) tensors for a Float8Tensor.

    Looks up the tensor's FP8 metadata (forward or backward, per its
    ``_fp8_meta_forward`` flag) at the tensor's meta index. Raises
    ``RuntimeError`` if the metadata has not been initialized.
    """
    meta = fp8_tensor._fp8_meta
    if meta is None:
        raise RuntimeError("FP8 meta data is not initialized.")
    meta_key = FP8GlobalStateManager.get_meta_tensor_key(
        forward=fp8_tensor._fp8_meta_forward,
    )
    idx = fp8_tensor._fp8_meta_index
    return (
        meta[meta_key].scale[idx],
        meta[meta_key].amax_history[0][idx],
        fp8_tensor._scale_inv,
    )
class FusedAdam(torch.optim.Optimizer): class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. """Implements Adam algorithm.
...@@ -50,9 +68,11 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -50,9 +68,11 @@ class FusedAdam(torch.optim.Optimizer):
method is called. (default: True) method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False) that can be used with CUDA Graphs. (default: False)
master_weights (bool, optional): whether to maintain FP32 master weights master_weights (list of torch.Tensor, optional): master weights to use
in the optimizer with FP16 mixed precision training, currently can for mixed precision training. If provided, the optimizer will update
only be used with capturable set to True. (default: False) the master weights and then cast the master weights to the model weights.
If not provided, the optimizer will update the model weights directly.
(default: None)
.. _Adam - A Method for Stochastic Optimization: .. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980 https://arxiv.org/abs/1412.6980
...@@ -72,15 +92,12 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -72,15 +92,12 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad=False, amsgrad=False,
set_grad_none=True, set_grad_none=True,
capturable=False, capturable=False,
master_weights=False, master_weights=None,
): ):
if amsgrad: if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.") raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
if master_weights and not capturable:
raise RuntimeError(
"Master weights is currently only supported with the capturable version."
)
# If the optimizer is capturable then LR should be a tensor (on GPU) # If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = dict( defaults = dict(
...@@ -95,20 +112,10 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -95,20 +112,10 @@ class FusedAdam(torch.optim.Optimizer):
self.set_grad_none = set_grad_none self.set_grad_none = set_grad_none
self.capturable = capturable self.capturable = capturable
self.master_weights = master_weights
# Create full precision master weights if master_weights is not None:
self.param_groups_master = [] assert isinstance(master_weights, list), "master_weights must be a list if provided"
for _, pg in enumerate(self.param_groups): self.master_weights = master_weights
param_list = pg["params"]
self.param_groups_master.append(
{
"params": [
p.clone().detach().float() if self.master_weights else None
for p in param_list
],
}
)
if capturable: if capturable:
for idx, group in enumerate(self.param_groups): for idx, group in enumerate(self.param_groups):
...@@ -123,6 +130,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -123,6 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
self.multi_tensor_adam = tex.multi_tensor_adam self.multi_tensor_adam = tex.multi_tensor_adam
self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
...@@ -147,7 +155,9 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -147,7 +155,9 @@ class FusedAdam(torch.optim.Optimizer):
if closure is not None: if closure is not None:
loss = closure() loss = closure()
for group, group_master in zip(self.param_groups, self.param_groups_master): master_param_idx = 0
for group in self.param_groups:
if len(group["params"]) == 0: if len(group["params"]) == 0:
continue continue
device = group["params"][0].device device = group["params"][0].device
...@@ -166,51 +176,131 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -166,51 +176,131 @@ class FusedAdam(torch.optim.Optimizer):
) )
# create lists for multi-tensor apply # create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], [] p_main_of_fp8_model = []
g_bf, p_bf, m_bf, v_bf = [], [], [], [] p_main_of_f16_model = []
g_32, p_32, m_32, v_32 = [], [], [], [] g_of_fp8_model = []
p_16_master = [] g_of_f16_model = []
p_32_master = [] g_of_f32_model = []
m_of_fp8_model = []
for p, p_master in zip(group["params"], group_master["params"]): m_of_f16_model = []
if p.grad is None: m_of_f32_model = []
continue v_of_fp8_model = []
if p.grad.data.is_sparse: v_of_f16_model = []
raise RuntimeError("FusedAdam does not support sparse gradients.") v_of_f32_model = []
p_fp8_model = []
p_f16_model = []
p_f32_model = []
# fp8 meta
scales = []
amaxes = []
scale_invs = []
# Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is.
out_dtype = tex.DType.kFloat32
has_fp16 = False
has_bf16 = False
for p in group["params"]:
state = self.state[p] state = self.state[p]
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
# Exponential moving average of gradient values # Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data).float() state["exp_avg"] = torch.zeros_like(p.data).float()
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data).float() state["exp_avg_sq"] = torch.zeros_like(p.data).float()
# Master weights
if self.master_weights and p.dtype != torch.float32:
# model weights can be fp32/bf16/fp16/fp8
# If it's fp32, it has no corresponding master weights
state["master_param"] = self.master_weights[master_param_idx]
master_param_idx += 1
assert (
state["master_param"].shape == p.shape
), "Master weights shape must match model weights shape"
else:
state["master_param"] = None
p_master = state["master_param"]
p_grad = p.grad
if self.master_weights and p_master is not None and p_master.grad is not None:
p_grad = p_master.grad
if p_grad is None:
continue
if p_grad.data.is_sparse:
raise RuntimeError("FusedAdam does not support sparse gradients.")
if p.dtype == torch.float16: if isinstance(p, Float8Tensor):
out_dtype = p._fp8_dtype
p_fp8_model.append(p._data.data)
scale, amax, scale_inv = get_fp8_meta(p)
scales.append(scale)
amaxes.append(amax)
scale_invs.append(scale_inv)
if self.master_weights: if self.master_weights:
p_16_master.append(p_master.data) p_main_of_fp8_model.append(p_master.data)
g_16.append(p.grad.data) g_of_fp8_model.append(p_grad.data)
p_16.append(p.data) m_of_fp8_model.append(state["exp_avg"])
m_16.append(state["exp_avg"]) v_of_fp8_model.append(state["exp_avg_sq"])
v_16.append(state["exp_avg_sq"]) elif p.dtype in [torch.float16, torch.bfloat16]:
elif p.dtype == torch.bfloat16: has_fp16 = has_fp16 or p.dtype == torch.float16
g_bf.append(p.grad) has_bf16 = has_bf16 or p.dtype == torch.bfloat16
p_bf.append(p) p_f16_model.append(p.data)
m_bf.append(state["exp_avg"])
v_bf.append(state["exp_avg_sq"])
elif p.dtype == torch.float32:
if self.master_weights: if self.master_weights:
p_32_master.append(p_master.data) p_main_of_f16_model.append(p_master.data)
g_32.append(p.grad.data) g_of_f16_model.append(p_grad.data)
p_32.append(p.data) m_of_f16_model.append(state["exp_avg"])
m_32.append(state["exp_avg"]) v_of_f16_model.append(state["exp_avg_sq"])
v_32.append(state["exp_avg_sq"]) elif p.dtype == torch.float32:
p_f32_model.append(p.data)
g_of_f32_model.append(p_grad.data)
m_of_f32_model.append(state["exp_avg"])
v_of_f32_model.append(state["exp_avg_sq"])
else: else:
raise RuntimeError("FusedAdam only support fp16 and fp32.") raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8")
if self.capturable and len(p_fp8_model) > 0:
raise RuntimeError(
"FusedAdam does not support FP8 model weights with capturable=True."
)
if has_fp16 and has_bf16:
# simple to add support for this, but not needed for now
raise RuntimeError(
"FusedAdam does not support a mix of float16 and bfloat16 model weights."
)
def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=None):
# Closures defined in a loop can have unexpected
# behavior when called outside the loop. However, this
# function is called in the same loop iteration as it
# is defined.
# pylint: disable=cell-var-from-loop
inv_scale_arg = () if inv_scale is None else (inv_scale,)
out_dtype_arg = () if out_dtype is None else (out_dtype,)
multi_tensor_applier(
adam_func,
self._dummy_overflow_buf,
tensor_lists,
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
*inv_scale_arg,
*out_dtype_arg,
)
if self.capturable:
# If the optimizer is capturable, then if there's a grad scaler it works # If the optimizer is capturable, then if there's a grad scaler it works
# on the GPU + a different multi_tensor_applier should be called # on the GPU + a different multi_tensor_applier should be called
if self.capturable:
# overflow check of gradients # overflow check of gradients
found_inf = ( found_inf = (
grad_scaler._check_inf_per_device(self)[device] grad_scaler._check_inf_per_device(self)[device]
...@@ -228,113 +318,76 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -228,113 +318,76 @@ class FusedAdam(torch.optim.Optimizer):
scale = torch.ones((1,), device=device) scale = torch.ones((1,), device=device)
inv_scale = torch.ones((1,), device=device) inv_scale = torch.ones((1,), device=device)
if len(g_16) > 0: if self.master_weights:
multi_tensor_applier( if len(p_f16_model) > 0:
( tensor_lists = [
self.multi_tensor_adam_capturable_master g_of_f16_model,
if self.master_weights p_f16_model,
else self.multi_tensor_adam_capturable m_of_f16_model,
), v_of_f16_model,
self._dummy_overflow_buf, p_main_of_f16_model,
( ]
[g_16, p_16, m_16, v_16, p_16_master] apply_multi_tensor_adam(
if self.master_weights self.multi_tensor_adam_capturable_master, tensor_lists, inv_scale
else [g_16, p_16, m_16, v_16]
),
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
)
if len(g_bf) > 0:
multi_tensor_applier(
self.multi_tensor_adam_capturable,
self._dummy_overflow_buf,
[g_bf, p_bf, m_bf, v_bf],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
) )
if len(p_f32_model) > 0:
if len(g_32) > 0: tensor_lists = [
multi_tensor_applier( g_of_f32_model,
( p_f32_model,
self.multi_tensor_adam_capturable_master m_of_f32_model,
if self.master_weights v_of_f32_model,
else self.multi_tensor_adam_capturable ]
), apply_multi_tensor_adam(
self._dummy_overflow_buf, self.multi_tensor_adam_capturable, tensor_lists, inv_scale
(
[g_32, p_32, m_32, v_32, p_32_master]
if self.master_weights
else [g_32, p_32, m_32, v_32]
),
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
inv_scale,
) )
else: else:
if len(g_16) > 0: if len(p_f16_model) > 0:
multi_tensor_applier( tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model]
self.multi_tensor_adam, apply_multi_tensor_adam(
self._dummy_overflow_buf, self.multi_tensor_adam_capturable, tensor_lists, inv_scale
[g_16, p_16, m_16, v_16],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
) )
if len(p_f32_model) > 0:
if len(g_bf) > 0: tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model]
multi_tensor_applier( apply_multi_tensor_adam(
self.multi_tensor_adam, self.multi_tensor_adam_capturable, tensor_lists, inv_scale
self._dummy_overflow_buf,
[g_bf, p_bf, m_bf, v_bf],
group["lr"],
beta1,
beta2,
group["eps"],
group["step"],
self.adam_w_mode,
bias_correction,
group["weight_decay"],
) )
if len(g_32) > 0: elif self.master_weights: # and self.capturable=False
multi_tensor_applier( if len(p_f16_model) > 0:
self.multi_tensor_adam, tensor_lists = [
self._dummy_overflow_buf, g_of_f16_model,
[g_32, p_32, m_32, v_32], p_f16_model,
group["lr"], m_of_f16_model,
beta1, v_of_f16_model,
beta2, p_main_of_f16_model,
group["eps"], ]
group["step"], apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
self.adam_w_mode, if len(p_fp8_model) > 0:
bias_correction, tensor_lists = [
group["weight_decay"], g_of_fp8_model,
) p_fp8_model,
m_of_fp8_model,
v_of_fp8_model,
p_main_of_fp8_model,
scales,
amaxes,
scale_invs,
]
apply_multi_tensor_adam(self.multi_tensor_adam_fp8, tensor_lists, out_dtype)
if len(p_f32_model) > 0:
tensor_lists = [
g_of_f32_model,
p_f32_model,
m_of_f32_model,
v_of_f32_model,
]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
else: # self.master_weights=False and self.capturable=False
if len(p_f16_model) > 0:
tensor_lists = [g_of_f16_model, p_f16_model, m_of_f16_model, v_of_f16_model]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
if len(p_f32_model) > 0:
tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
return loss return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment