Commit eac0d49b authored by yuguo's avatar yuguo
Browse files

[DCU] fix batchlinear core dump in 2.5

parent d8041744
......@@ -968,12 +968,12 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
int math_sm_count, int batch_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_batchgemm);
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
Tensor *outputD = reinterpret_cast<Tensor *>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
}
......
......@@ -52,7 +52,7 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm":
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32):
out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
return out, None, rsigma
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment