Commit 7405fe09 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.3'

parents 7462e0e4 c636071d
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <cassert> #include <cassert>
#include <numeric> #include <numeric>
#include "amd_detail/hip_float8.h"
#include "common/common.h" #include "common/common.h"
#include "common/util/cuda_driver.h" #include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
...@@ -330,7 +331,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te ...@@ -330,7 +331,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert(rs_output.element_size() == 2); assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
#ifdef USE_ROCM #ifdef USE_ROCM
reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, reducescatter2_userbuff_fp8<te_hip_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#else #else
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
#endif #endif
......
...@@ -2034,12 +2034,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const ...@@ -2034,12 +2034,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::bf8>>( template void reducescatter2_userbuff_stridedoutput_fp8<te_hip_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event); cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_stridedoutput_fp8<hip_f8<hip_f8_type::fp8>>( template void reducescatter2_userbuff_stridedoutput_fp8<te_hip_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
cudaEvent_t comm_launch_event); cudaEvent_t comm_launch_event);
...@@ -2052,30 +2052,30 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, ...@@ -2052,30 +2052,30 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler,
comm, stream, comm_launch_event); comm, stream, comm_launch_event);
} }
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::bf8>>(void *output, float *scale, template void reducescatter2_userbuff_fp8<te_hip_fp8_e5m2>(void *output, float *scale,
const int handler, const int offset, const int handler, const int offset,
const int elements, communicator *comm, const int elements, communicator *comm,
cudaStream_t stream, cudaStream_t stream,
cudaEvent_t comm_launch_event); cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_fp8<hip_f8<hip_f8_type::fp8>>(void *output, float *scale, template void reducescatter2_userbuff_fp8<te_hip_fp8_e4m3>(void *output, float *scale,
const int handler, const int offset, const int handler, const int offset,
const int elements, communicator *comm, const int elements, communicator *comm,
cudaStream_t stream, cudaStream_t stream,
cudaEvent_t comm_launch_event); cudaEvent_t comm_launch_event);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::fp8>>( template void reducescatter2_userbuff_strided_atomic_fp8<te_hip_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<hip_f8<hip_f8_type::bf8>>( template void reducescatter2_userbuff_strided_atomic_fp8<te_hip_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::fp8>>( template void reducescatter2_userbuff_strided_multiatomic_fp8<te_hip_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<hip_f8<hip_f8_type::bf8>>( template void reducescatter2_userbuff_strided_multiatomic_fp8<te_hip_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
...@@ -2845,10 +2845,10 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in ...@@ -2845,10 +2845,10 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::fp8>>(void *inputs, void *output, float *scale, template void reduce_fp8_in_bf16_out<te_hip_fp8_e4m3>(void *inputs, void *output, float *scale,
int num_inputs, int input_size, int num_inputs, int input_size,
cudaStream_t stream); cudaStream_t stream);
template void reduce_fp8_in_bf16_out<hip_f8<hip_f8_type::bf8>>(void *inputs, void *output, float *scale, template void reduce_fp8_in_bf16_out<te_hip_fp8_e5m2>(void *inputs, void *output, float *scale,
int num_inputs, int input_size, int num_inputs, int input_size,
cudaStream_t stream); cudaStream_t stream);
#else #else
......
...@@ -334,8 +334,8 @@ using fp8e4m3 = __nv_fp8_e4m3; ...@@ -334,8 +334,8 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#else #else
using bf16 = __hip_bfloat16; using bf16 = __hip_bfloat16;
using fp8e4m3 = hip_f8<hip_f8_type::fp8>; using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>; using fp8e5m2 = te_hip_fp8_e5m2;
#endif #endif
template <typename T> template <typename T>
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include "../common.h" #include "../common.h"
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
using __nv_fp8_e4m3 = hip_f8<hip_f8_type::fp8>; using __nv_fp8_e4m3 = te_hip_fp8_e4m3;
using __nv_fp8_e5m2 = hip_f8<hip_f8_type::bf8>; using __nv_fp8_e5m2 = te_hip_fp8_e5m2;
#define __ldlu(x) __ldg(x) #define __ldlu(x) __ldg(x)
#endif #endif
......
...@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// broadcast the amax to all threads in a warp from the lane 0 // broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0; constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = __shfl(warp_tile_amax, lane_zero); warp_tile_amax = __shfl(warp_tile_amax, lane_zero, THREADS_PER_WARP);
__syncthreads();
#else #else
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
#endif #endif
...@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose ...@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
// broadcast the amax to all threads in a warp from the lane 0 // broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0; constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = __shfl(warp_tile_amax, lane_zero); warp_tile_amax = __shfl(warp_tile_amax, lane_zero, THREADS_PER_WARP);
__syncthreads();
#else #else
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
#endif #endif
......
...@@ -986,8 +986,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float ...@@ -986,8 +986,8 @@ __device__ __forceinline__ void reciprocal<float>(float *value_inv, const float
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#else #else
using fp8e4m3 = hip_f8<hip_f8_type::fp8>; using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>; using fp8e5m2 = te_hip_fp8_e5m2;
#endif #endif
using e8m0_t = uint8_t; using e8m0_t = uint8_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment