Commit a9601800 authored by wenjh
Browse files

Fix build error


Signed-off-by: wenjh <wenjh@sugon.com>
parent 5cf21c3b
......@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
}
// Destroys a hipBLASLt handle previously created via CreateHipBlasLtHandle.
// A null handle is a no-op, so this is safe to call on a never-initialized
// (or already-cleared) handle. hipblasLtDestroy's status is checked through
// NVTE_CHECK_HIPBLASLT, which raises on any non-success return code.
static void DestroyHipBlasLtHandle(hipblasLtHandle_t handle) {
  if (handle != nullptr) {
    NVTE_CHECK_HIPBLASLT(hipblasLtDestroy(handle));
  }
}
......@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
{
HipBlasltUserArgsCache() {}
HipBlasltUserArgsCache(const HipBlasltUserArgsCache&) = delete;
HipBlasltUserArgsBuffer& operator=(const HipBlasltUserArgsBuffer&) = delete;
HipBlasltUserArgsCache& operator=(const HipBlasltUserArgsCache&) = delete;
HipBlasltUserArgsBuffer& getBuffer(hipStream_t stream, size_t size, bool host)
{
std::unordered_map<size_t, HipBlasltUserArgsBuffer>& buffers = host ? host_buffers_: device_buffers_;
......@@ -1524,13 +1524,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
DType input_type = inputB[0]->data.dtype;
DType bias_type = bias[0]->data.dtype;
NVTE_CHECK(bias_type == DType::kFloat32 || bias_type == DType::kFloat16 || bias_type == DType::kBFloat16);
for (int i = 0; i < m.size(); ++i) {
void* input_ptr = inputB[i]->data.dptr;
void* bias_ptr = bias[i]->data.dptr;
batch_size = k[i];
output_dim = n[i];
int batch_size = static_cast<int>(k[i]);
int output_dim = static_cast<int>(n[i]);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input_dtype, IType,
input_type, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
bias_type, OType,
detail::bias_gradient_kernelLauncher<IType, OType>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment