Commit cbb14a5f authored by wenjh
Browse files

Fix build error


Signed-off-by: wenjh <wenjh@sugon.com>
parent b3dcfc28
......@@ -14,6 +14,7 @@
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"
#ifdef __HIP_PLATFORM_AMD__
namespace transformer_engine {
namespace {
constexpr uint32_t WARP_SIZE = 32;
......@@ -311,11 +312,16 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor*
}
} // namespace transformer_engine
#endif
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors);
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK(false, "nvte_swizzle_block_scaling_to_mxfp8_scaling_factors is not supported on rocm");
#else
using namespace transformer_engine;
swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input),
convertNVTETensorCheck(output), stream);
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment