Commit b27e513d authored by wenjh

[ROCM6.3] Fix build on rocm-6.3


Signed-off-by: wenjh <wenjh@sugon.com>
parent 92d59fe4
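
Per the commit title, the direct hip_f8<hip_f8_type::fp8> / hip_f8<hip_f8_type::bf8> spellings break the build on ROCm 6.3; the diff below replaces them with the project aliases te_hip_fp8_e4m3 and te_hip_fp8_e5m2 and adds the common/common.h include where those aliases are declared. The diff does not show how the aliases are defined; the lines that follow are only a hedged sketch of what such a mapping could look like, assuming the fnuz FP8 types from ROCm's hip/hip_fp8.h header are the intended backing types (the real definitions live in common/common.h in this repository).

// Hypothetical sketch, not part of this commit: one way common/common.h could define
// the te_hip_fp8_* aliases so kernel code no longer names hip_f8<> directly.
// Whether the aliases resolve to ROCm's fnuz FP8 types or to something else is an
// assumption made for illustration only.
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_fp8.h>                      // ROCm FP8 header (assumed available on 6.3)
using te_hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz;  // 4-bit exponent, 3-bit mantissa storage type
using te_hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;  // 5-bit exponent, 2-bit mantissa storage type
#endif

Keeping the FP8 type behind an alias leaves the __nv_fp8_e4m3 / __nv_fp8_e5m2 branches untouched on CUDA, while the ROCm side can track whichever FP8 representation the installed toolkit provides.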
@@ -11,6 +11,7 @@
#include <cassert>
#include <numeric>
#include "amd_detail/hip_float8.h"
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
@@ -289,7 +290,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
#ifdef USE_ROCM
reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
reducescatter2_userbuff_fp8<te_hip_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#else
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#endif
@@ -2034,12 +2034,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
}
#ifdef __HIP_PLATFORM_AMD__
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::bf8>>(
template void reducescatter2_userbuff_stridedoutput_fp8<te_hip_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::fp8>>(
template void reducescatter2_userbuff_stridedoutput_fp8<te_hip_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event);
@@ -2052,30 +2052,30 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler,
comm, stream, comm_launch_event);
}
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(void *output, float *scale,
template void reducescatter2_userbuff_fp8<te_hip_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::fp8>>(void *output, float *scale,
template void reducescatter2_userbuff_fp8<te_hip_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset,
const int elements, communicator *comm,
cudaStream_t stream,
cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::fp8>>(
template void reducescatter2_userbuff_strided_atomic_fp8<te_hip_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::bf8>>(
template void reducescatter2_userbuff_strided_atomic_fp8<te_hip_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::fp8>>(
template void reducescatter2_userbuff_strided_multiatomic_fp8<te_hip_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::bf8>>(
template void reducescatter2_userbuff_strided_multiatomic_fp8<te_hip_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
@@ -2845,10 +2845,10 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
}
#ifdef __HIP_PLATFORM_AMD__
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::fp8>>(void *inputs, void *output, float *scale,
template void reduce_fp8_in_bf16_out<te_hip_fp8_e4m3>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::bf8>>(void *inputs, void *output, float *scale,
template void reduce_fp8_in_bf16_out<te_hip_fp8_e5m2>(void *inputs, void *output, float *scale,
int num_inputs, int input_size,
cudaStream_t stream);
#else
@@ -334,8 +334,8 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using bf16 = __hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2;
#endif
template <typename T>
@@ -9,8 +9,8 @@
#include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>;
using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>;
using __nv_fp8_e4m3 = te_hip_fp8_e4m3;
using __nv_fp8_e5m2 = te_hip_fp8_e5m2;
#define __ldlu(x) __ldg(x)
#endif
@@ -986,8 +986,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2;
#endif
using e8m0_t = uint8_t;