"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "b9ec4909e674c0377ad8d6da210cd819df7f7f5a"
Unverified Commit bda29934, authored by Tim Moon, committed by GitHub
Browse files

Handle dtypes more carefully in multi-tensor Adam (#1888)



* Add dtype checks in multi-tensor Adam
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid throwing exception in destructor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0a1499fa
...@@ -225,8 +225,8 @@ struct AdamFunctorMasterParamRemainder { ...@@ -225,8 +225,8 @@ struct AdamFunctorMasterParamRemainder {
r_m[ii] = static_cast<MATH_T>(m[i]); r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]); r_v[ii] = static_cast<MATH_T>(v[i]);
local_p[ii] = static_cast<int16_t>(p[i]); local_p[ii] = p[i];
local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]); local_p_rem[ii] = p_remainder[i];
} else { } else {
r_g[ii] = MATH_T(0); r_g[ii] = MATH_T(0);
r_m[ii] = MATH_T(0); r_m[ii] = MATH_T(0);
...@@ -280,8 +280,8 @@ struct AdamFunctorMasterParamRemainder { ...@@ -280,8 +280,8 @@ struct AdamFunctorMasterParamRemainder {
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]); p_remainder[i] = local_p_rem[ii];
p[i] = static_cast<int16_t>(local_p[ii]); p[i] = local_p[ii];
m[i] = static_cast<FULL_T>(r_m[ii]); m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]); v[i] = static_cast<FULL_T>(r_v[ii]);
...@@ -466,8 +466,8 @@ struct AdamCapturableFunctor { ...@@ -466,8 +466,8 @@ struct AdamCapturableFunctor {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
p[i] = static_cast<T>(r_p[ii]); p[i] = static_cast<T>(r_p[ii]);
m[i] = static_cast<T>(r_m[ii]); m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<T>(r_v[ii]); v[i] = static_cast<FULL_T>(r_v[ii]);
} }
} }
} }
...@@ -577,9 +577,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -577,9 +577,6 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) { const float weight_decay, const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -587,16 +584,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -587,16 +584,48 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
bias_correction2 = 1 - std::pow(beta2, step); bias_correction2 = 1 - std::pow(beta2, step);
} }
size_t max_size = 0; // Check tensor list sizes
// 4 tensor lists: g, p, m, v
// 5 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5,
"Expected 4 or 5 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == p_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(p_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
if (num_tensor_lists == 5) {
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
}
// Check if 64-bit indices are required
bool requires_64bit_indexing = false; bool requires_64bit_indexing = false;
for (size_t i = 0; i < num_tensor_lists; i++) { for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) { for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) { if (tensor_lists[i][j]->numel() >= INT_MAX) {
max_size = tensor_lists[i][j]->numel(); requires_64bit_indexing = true;
if (max_size >= INT_MAX) { break;
requires_64bit_indexing = true;
break;
}
} }
} }
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
...@@ -604,16 +633,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -604,16 +633,10 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
} }
} }
const auto g_in_type_te = tensor_lists[0][0]->dtype(); // Launch kernel
const auto p_in_type_te = tensor_lists[1][0]->dtype();
// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
if (num_tensor_lists == 4) { if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now // g, p, m, v
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -637,7 +660,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -637,7 +660,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
} }
} else { } else {
if (num_tensor_lists == 4) { if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now // g, p, m, v
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -647,6 +670,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -647,6 +670,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
stream, beta1, beta2, bias_correction1, bias_correction2, stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);)); epsilon, lr, (adamMode_t)mode, weight_decay);));
} else { } else {
// g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -667,8 +691,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, ...@@ -667,8 +691,6 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
const float epsilon, const int step, const int mode, const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) { const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -676,23 +698,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, ...@@ -676,23 +698,43 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
bias_correction2 = 1 - std::pow(beta2, step); bias_correction2 = 1 - std::pow(beta2, step);
} }
const auto g_in_type_te = tensor_lists[0][0]->dtype(); // Check tensor list sizes
const auto p_in_type_te = tensor_lists[1][0]->dtype(); // 5 tensor lists: g, p, m, v, p_remainder
const size_t num_tensor_lists = tensor_lists.size();
// case 5: g, p, m, v, p_master NVTE_CHECK(num_tensor_lists == 5, "Expected 5 tensor lists, but found ", num_tensor_lists);
NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5"); const size_t num_tensors_per_list = tensor_lists[0].size();
NVTE_CHECK(p_in_type_te == DType::kBFloat16, for (size_t i = 1; i < num_tensor_lists; i++) {
"Adam with BF16 param remainders requires BF16 params"); NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// g, p, m, v, p_master // Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == DType::kBFloat16, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(DType::kBFloat16));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kInt16, "Param remainder tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kInt16));
}
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id, AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);); (adamMode_t)mode, weight_decay););
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -702,9 +744,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -702,9 +744,6 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const DType fp8_dtype, const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) { const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -712,16 +751,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -712,16 +751,53 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
bias_correction2 = 1 - std::pow(beta2, step); bias_correction2 = 1 - std::pow(beta2, step);
} }
size_t max_size = 0; // Check tensor list sizes
// 8 tensor lists: g, p_fp8, m, v, p_master, scale, amax, scale_inv
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 8, "Expected 8 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(
tensor_lists[1][j]->dtype() == fp8_dtype || tensor_lists[1][j]->dtype() == DType::kByte,
"Param tensor ", j, " has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(fp8_dtype));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[5][j]->dtype() == DType::kFloat32, "Scale tensor ", j,
" has dtype=", to_string(tensor_lists[5][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[6][j]->dtype() == DType::kFloat32, "Absmax tensor ", j,
" has dtype=", to_string(tensor_lists[6][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[7][j]->dtype() == DType::kFloat32, "Scale-inverse tensor ", j,
" has dtype=", to_string(tensor_lists[7][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
// Check if 64-bit indices are required
bool requires_64bit_indexing = false; bool requires_64bit_indexing = false;
for (size_t i = 0; i < num_tensor_lists; i++) { for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) { for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) { if (tensor_lists[i][j]->numel() >= INT_MAX) {
max_size = tensor_lists[i][j]->numel(); requires_64bit_indexing = true;
if (max_size >= INT_MAX) { break;
requires_64bit_indexing = true;
break;
}
} }
} }
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
...@@ -729,11 +805,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -729,11 +805,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
} }
} }
const auto g_in_type_te = tensor_lists[0][0]->dtype(); // Launch kernel
// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
if (requires_64bit_indexing) { if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T, fp8_dtype, FP8_T,
...@@ -764,6 +836,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, ...@@ -764,6 +836,34 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
Tensor step, const int mode, const int bias_correction, Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale, const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) { const int device_id, cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 4, "Expected 4 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
...@@ -782,6 +882,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, ...@@ -782,6 +882,37 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
Tensor inv_scale, const int device_id, Tensor inv_scale, const int device_id,
cudaStream_t stream) { cudaStream_t stream) {
// Check tensor list sizes
// 5 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
NVTE_CHECK(num_tensor_lists == 5, "Expected 4 tensor lists, but found ", num_tensor_lists);
const size_t num_tensors_per_list = tensor_lists[0].size();
for (size_t i = 1; i < num_tensor_lists; i++) {
NVTE_CHECK(tensor_lists[i].size() == num_tensors_per_list, "Tensor list ", i,
" has size=", tensor_lists[i].size(), ", but expected size=", num_tensors_per_list);
}
// Check tensor dtypes
const auto g_in_type_te = tensor_lists[0][0]->dtype();
for (size_t j = 0; j < num_tensors_per_list; j++) {
NVTE_CHECK(tensor_lists[0][j]->dtype() == g_in_type_te, "Grad tensor ", j,
" has dtype=", to_string(tensor_lists[0][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[1][j]->dtype() == g_in_type_te, "Param tensor ", j,
" has dtype=", to_string(tensor_lists[1][j]->dtype()),
", but expected dtype=", to_string(g_in_type_te));
NVTE_CHECK(tensor_lists[2][j]->dtype() == DType::kFloat32, "First moment tensor ", j,
" has dtype=", to_string(tensor_lists[2][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[3][j]->dtype() == DType::kFloat32, "Second moment tensor ", j,
" has dtype=", to_string(tensor_lists[3][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
NVTE_CHECK(tensor_lists[4][j]->dtype() == DType::kFloat32, "Master param tensor ", j,
" has dtype=", to_string(tensor_lists[4][j]->dtype()),
", but expected dtype=", to_string(DType::kFloat32));
}
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
......
...@@ -52,7 +52,7 @@ class OptionalCUDAGuard { ...@@ -52,7 +52,7 @@ class OptionalCUDAGuard {
~OptionalCUDAGuard() { ~OptionalCUDAGuard() {
if (device_changed_) { if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_)); cudaSetDevice(prev_device_);
} }
} }
......
...@@ -46,6 +46,8 @@ std::string to_string(const DType type) { ...@@ -46,6 +46,8 @@ std::string to_string(const DType type) {
return "Float8E8M0"; return "Float8E8M0";
case DType::kFloat4E2M1: case DType::kFloat4E2M1:
return "Float4E2M1"; return "Float4E2M1";
case DType::kInt16:
return "Int16";
case DType::kInt32: case DType::kInt32:
return "Int32"; return "Int32";
case DType::kInt64: case DType::kInt64:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment