Commit 99e60246 authored by wenjh

Make release_v2.9 compile pass

parent cbb14a5f
@@ -130,7 +130,7 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
  */
 void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                          const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
-                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset);
+                         NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
 /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
  * allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated)
...
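Note on the hunk above: giving the three new trailing parameters default values lets call sites written against the pre-v2.9 signature keep compiling unchanged, which matches the commit's goal. A minimal sketch of a hypothetical call site (the argument variables are illustrative, not from this repo):

// Hypothetical caller, assuming transa/transb, alpha/beta, the tensors,
// workspace, config and stream are already set up. With the defaults
// added above, the old call without the new trailing arguments compiles:
nvte_cublas_gemm_v2(transa, transb, &alpha, A, B, &beta, C, D,
                    workspace, config, stream);  // new params default to 0
// Callers that want the new behavior pass the extras explicitly:
nvte_cublas_gemm_v2(transa, transb, &alpha, A, B, &beta, C, D,
                    workspace, config, stream,
                    /*nvte_use_hipblaslt=*/true, /*nvte_use_rocblas=*/false,
                    /*compute_stream_offset=*/0);

Two caveats: `= false` would be more idiomatic than `= 0` for the bool parameters, and C++ default arguments are invisible to C translation units, so this only helps callers that include the header as C++.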
@@ -14,7 +14,7 @@
 #include "../util/logging.h"
 #include "transformer_engine/transformer_engine.h"
-#ifdef __HIP_PLATFORM_AMD__
+#ifndef __HIP_PLATFORM_AMD__
 namespace transformer_engine {
 namespace {
 constexpr uint32_t WARP_SIZE = 32;
...
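The `#ifdef` to `#ifndef` flip above is consistent with `WARP_SIZE = 32` being an NVIDIA-specific constant: HIP on AMD hardware typically exposes 64-lane wavefronts, so this block should only compile when `__HIP_PLATFORM_AMD__` is not defined. A minimal sketch of the usual guard pattern (the AMD-side value here is an assumption, not taken from this repo):

#include <cstdint>

// Sketch only: CDNA-class AMD GPUs execute 64-wide wavefronts, while
// NVIDIA warps are 32-wide, so the constant is selected per platform.
#ifdef __HIP_PLATFORM_AMD__
constexpr uint32_t WARP_SIZE = 64;  // AMD wavefront width (assumed here)
#else
constexpr uint32_t WARP_SIZE = 32;  // NVIDIA warp width
#endif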
@@ -228,7 +228,7 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp
   // Allocate memory for swizzled mxfp8 scaling factors
   const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
   at::Tensor swizzled_scale_inv = at::empty(
-      std::vector<int64_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options);
+      std::vector<int64_t>{static_cast<int64_t>(swizzled_scale_inv_first_dim), static_cast<int64_t>(swizzled_scale_inv_last_dim)}, options);
   // Set rowwise scaling factors on output
   void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
   NVTEShape swizzled_scale_inv_shape{};
...
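The `static_cast<int64_t>` casts above are what make this hunk compile: if the two dimension variables are unsigned (e.g. `size_t`), initializing a `std::vector<int64_t>` from a braced list is a narrowing conversion, which C++ list-initialization rejects for non-constant values. A standalone reproduction with hypothetical variable names:

#include <cstdint>
#include <vector>

int main() {
  std::size_t first_dim = 128, last_dim = 4;  // unsigned, as dims often are
  // std::vector<int64_t> bad{first_dim, last_dim};  // error: narrowing size_t -> int64_t
  std::vector<int64_t> ok{static_cast<int64_t>(first_dim),
                          static_cast<int64_t>(last_dim)};  // explicit cast compiles
  return ok.size() == 2 ? 0 : 1;
}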