Unverified Commit e2caf78d authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[JAX] Fixing `unused-variable` warning at TE/JAX extension compile (#937)



replaced plain C asserts with NVTE_CHECK to avoid unused-variable warnings
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
parent 4a4f05da
...@@ -108,7 +108,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -108,7 +108,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
float *scale_inv = reinterpret_cast<float *>(buffers[3]); float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4]; auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]); float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out); NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
...@@ -221,7 +221,8 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -221,7 +221,8 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
void *workspace_ptr = buffers[9]; void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out); NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
scale = nullptr; scale = nullptr;
scale_inv = nullptr; scale_inv = nullptr;
...@@ -291,7 +292,8 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -291,7 +292,8 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
float *amax_out = reinterpret_cast<float *>(buffers[7]); float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out); NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
scale = nullptr; scale = nullptr;
scale_inv = nullptr; scale_inv = nullptr;
......
...@@ -214,7 +214,8 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -214,7 +214,8 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *amax_out = buffers[9]; auto *amax_out = buffers[9];
auto *workspace = buffers[10]; auto *workspace = buffers[10];
auto *barrier = buffers[11]; auto *barrier = buffers[11];
assert(amax_out == amax); NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size; auto batch_size = desc.batch_size;
...@@ -227,7 +228,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -227,7 +228,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto barrier_dtype = desc.barrier_dtype; auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
...@@ -263,7 +263,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -263,7 +263,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto eps = desc.eps; auto eps = desc.eps;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
...@@ -288,7 +287,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -288,7 +287,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dbeta_part_dtype = desc.dbeta_part_dtype; auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto *ograd = buffers[0]; auto *ograd = buffers[0];
auto *mu = buffers[1]; auto *mu = buffers[1];
...@@ -321,7 +319,7 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -321,7 +319,7 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *amax_out = buffers[7]; auto *amax_out = buffers[7];
auto *workspace = buffers[8]; auto *workspace = buffers[8];
auto *barrier = buffers[9]; auto *barrier = buffers[9];
assert(amax_out == amax); NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX RSMNormForwardFP8 primitive.");
void *bias = nullptr; void *bias = nullptr;
void *mu = nullptr; void *mu = nullptr;
...@@ -337,7 +335,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -337,7 +335,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto barrier_dtype = desc.barrier_dtype; auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
...@@ -371,7 +368,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -371,7 +368,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto barrier_dtype = desc.barrier_dtype; auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
......
...@@ -17,7 +17,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -17,7 +17,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto *scale_inv = reinterpret_cast<float *>(buffers[3]); auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4]; auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]); auto *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out); NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector(); auto shape = desc.shape.to_vector();
......
...@@ -43,7 +43,7 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -43,7 +43,7 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
auto *input_cast = buffers[4]; auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5]; auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]); float *amax_out = reinterpret_cast<float *>(buffers[6]);
assert(amax == amax_out); NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
...@@ -100,7 +100,8 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -100,7 +100,8 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void *workspace_ptr = buffers[8]; void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out); NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
scale = nullptr; scale = nullptr;
scale_inv = nullptr; scale_inv = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment