Unverified Commit e3bb24e5 authored by Zhenhuan Liu, committed by GitHub

[MoE][PyTorch] Fix size mismatch error in fp8 transpose. (#988)



Fix size mismatch error in fp8 transpose.
Signed-off-by: Dennis Liu <denliu@nvidia.com>
parent 56e0b351
@@ -285,9 +285,9 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) {
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));
-  if (M == 0 || N == 0) return input;
   auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte);
+  if (M == 0 || N == 0) return output;
   auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
   auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
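The patch moves the empty-tensor early return to after the output allocation: `fp8_transpose` of an `(M, N)` input must always hand back an `(N, M)` tensor, and the old `return input;` leaked the untransposed shape to callers whenever `M == 0 || N == 0`. A minimal sketch of the shape-level behavior, in pure Python with hypothetical helper names (not the actual extension code):

```python
def fp8_transpose_shape_buggy(shape):
    """Pre-fix behavior: the empty case returns the input's own shape."""
    M, N = shape
    if M == 0 or N == 0:
        return (M, N)   # bug: shape is not transposed, causing a size mismatch downstream
    return (N, M)

def fp8_transpose_shape_fixed(shape):
    """Post-fix behavior: allocate the (N, M) output first, then early-return it."""
    M, N = shape
    out = (N, M)        # output created with transposed dims, as in the patch
    if M == 0 or N == 0:
        return out      # empty case now carries the correct (N, M) shape
    return out

print(fp8_transpose_shape_buggy((0, 128)))  # (0, 128) — wrong shape for a transpose
print(fp8_transpose_shape_fixed((0, 128)))  # (128, 0) — matches what callers expect
```

The actual fix also has the practical effect that the empty path returns a freshly allocated `kByte` tensor rather than aliasing the input, which is consistent with the non-empty path.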