Commit a9601800 authored by wenjh
Browse files

Fix build error


Signed-off-by: wenjh <wenjh@sugon.com>
parent 5cf21c3b
...@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) { ...@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
} }
static void DestroyHipBlasLtHandle(hipblasLtHandle_t handle) { static void DestroyHipBlasLtHandle(hipblasLtHandle_t handle) {
if(handle != nullptr) if(handle != nullptr) {
NVTE_CHECK_HIPBLASLT(hipblasLtDestroy(handle)); NVTE_CHECK_HIPBLASLT(hipblasLtDestroy(handle));
} }
} }
...@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache ...@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
{ {
HipBlasltUserArgsCache() {} HipBlasltUserArgsCache() {}
HipBlasltUserArgsCache(const HipBlasltUserArgsCache&) = delete; HipBlasltUserArgsCache(const HipBlasltUserArgsCache&) = delete;
HipBlasltUserArgsBuffer& operator=(const HipBlasltUserArgsBuffer&) = delete; HipBlasltUserArgsCache& operator=(const HipBlasltUserArgsCache&) = delete;
HipBlasltUserArgsBuffer& getBuffer(hipStream_t stream, size_t size, bool host) HipBlasltUserArgsBuffer& getBuffer(hipStream_t stream, size_t size, bool host)
{ {
std::unordered_map<size_t, HipBlasltUserArgsBuffer>& buffers = host ? host_buffers_: device_buffers_; std::unordered_map<size_t, HipBlasltUserArgsBuffer>& buffers = host ? host_buffers_: device_buffers_;
...@@ -1524,13 +1524,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1524,13 +1524,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
DType input_type = inputB[0]->data.dtype; DType input_type = inputB[0]->data.dtype;
DType bias_type = bias[0]->data.dtype; DType bias_type = bias[0]->data.dtype;
NVTE_CHECK(bias_type == DType::kFloat32 || bias_type == DType::kFloat16 || bias_type == DType::kBFloat16); NVTE_CHECK(bias_type == DType::kFloat32 || bias_type == DType::kFloat16 || bias_type == DType::kBFloat16);
for (int i = 0; i < m.size(); ++i) { for (int i = 0; i < m.size(); ++i) {
void* input_ptr = inputB[i]->data.dptr; void* input_ptr = inputB[i]->data.dptr;
void* bias_ptr = bias[i]->data.dptr; void* bias_ptr = bias[i]->data.dptr;
batch_size = k[i]; int batch_size = static_cast<int>(k[i]);
output_dim = n[i]; int output_dim = static_cast<int>(n[i]);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input_dtype, IType, input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
bias_type, OType, bias_type, OType,
detail::bias_gradient_kernelLauncher<IType, OType>( detail::bias_gradient_kernelLauncher<IType, OType>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment