Commit e2caf78d (unverified), authored by Alp Dener, committed by GitHub
Browse files

[JAX] Fixing `unused-variable` warning at TE/JAX extension compile (#937)



Replaced plain C asserts with NVTE_CHECK to avoid unused-variable warnings.
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 4a4f05da
......@@ -108,7 +108,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
......@@ -221,7 +221,8 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
......@@ -291,7 +292,8 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
......
......@@ -214,7 +214,8 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *amax_out = buffers[9];
auto *workspace = buffers[10];
auto *barrier = buffers[11];
assert(amax_out == amax);
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
......@@ -227,7 +228,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
......@@ -263,7 +263,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto eps = desc.eps;
auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
......@@ -288,7 +287,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto *ograd = buffers[0];
auto *mu = buffers[1];
......@@ -321,7 +319,7 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *amax_out = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
assert(amax_out == amax);
NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX RSMNormForwardFP8 primitive.");
void *bias = nullptr;
void *mu = nullptr;
......@@ -337,7 +335,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
......@@ -371,7 +368,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
......
......@@ -17,7 +17,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
......
......@@ -43,7 +43,7 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]);
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
......@@ -100,7 +100,8 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment