Commit 33bafa16 authored by lishen's avatar lishen
Browse files

lowlatency combine实现3级流水

parent 61bc0aff
This diff is collapsed.
......@@ -17,7 +17,7 @@ namespace internode_ll {
template <int kNumSendUnrolls>
__forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4* dst_buffer, __hip_bfloat162* shared_amaxmin, const int& lane_id) {
__forceinline__ __device__ int logfmt_encode(int4* lds_buffer, __hip_bfloat162* shared_amaxmin, const int& lane_id) {
EP_STATIC_ASSERT(kNumSendUnrolls == 2, "kNumSendUnrolls == 2 only");
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(__hip_bfloat16); // 8
......@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
const auto& bf162_values = reinterpret_cast<__hip_bfloat162*>(int4_values);
// Calculate lane offset
const auto& ld_buffer = cpy_src_int4_ptr + lane_id * kNumSendUnrolls;
const auto& ld_buffer = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(lds_buffer) + lane_id * kSendValueBytes);
const auto& st_buffer = reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(lds_buffer) + lane_id * kSendValueBytes * 10 / 16);
// Local log amax
auto bf162_amax = __hip_bfloat162(HIPRT_ZERO_BF16, HIPRT_ZERO_BF16);
......@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
const auto& fp162_max = __bfloat1622float2(bf162_amax);
auto amax = __builtin_fmaxf(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
auto amin = __builtin_fminf(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
......@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
if (shared_amaxmin != nullptr) {
*shared_amaxmin = __hip_bfloat162(amax, amin);
}
syncwarp();
// Calculate log amin/amax float
const auto& log_amax = __builtin_log2f(amax);
const auto& log_amin = __builtin_fmaxf(__builtin_log2f(amin), log_amax - kMinClip);
const auto& log_amax = __builtin_amdgcn_logf(amax);
const auto& log_amin = __builtin_fmaxf(__builtin_amdgcn_logf(amin), log_amax - kMinClip);
// 在组内广播enable_cast结果
const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);
// Case into LogFMT-10 if satisfied
if (enable_cast) {
constexpr int dst_buffer_step = kSendValueBytes * 10 / 16;
const auto& st_buffer = reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * dst_buffer_step);
uint32_t st_u32_values[dst_buffer_step / sizeof(uint32_t)]; // = 5
// 计算10bit数据的两个相邻数值的差值
const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
const auto step_inv = 1.0f / step;
// 计算舍入值
const auto rounding = 2.0f - __builtin_log2f((1.0f + __builtin_exp2f(step)) * 0.5f) * step_inv;
const auto rounding = 2.0f - __builtin_amdgcn_logf((1.0f + __builtin_amdgcn_exp2f(step)) * 0.5f) * step_inv;
const auto fused_rounding = rounding - log_amin * step_inv;
// 用于存储编码后的值
......@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
#pragma unroll
for (int k = 0; k < kNumElemsPerInt4; ++k) { // 8
// 将 bfloat162 转换为 float2
const auto& fp162_fvalue = __bfloat1622float2(bf162_values[k]);
const auto& fp322_fvalue = __bfloat1622float2(bf162_values[k]);
/*
实际进行压缩的公式为:
......@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
K: 压缩后的整数的最大值(即,K为2的幂)
*/
// 对 float 值进行编码
encoded[k * 2 + 0] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.x) * step_inv + fused_rounding, 0));
encoded[k * 2 + 1] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.y) * step_inv + fused_rounding, 0));
encoded[k * 2 + 0] = __float2uint_rd(__builtin_fmaxf(__builtin_amdgcn_logf(fp322_fvalue.x) * step_inv + fused_rounding, 0));
encoded[k * 2 + 1] = __float2uint_rd(__builtin_fmaxf(__builtin_amdgcn_logf(fp322_fvalue.y) * step_inv + fused_rounding, 0));
}
// 批量打包编码后的值到 st_buffer
st_u32_values[0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);
st_u32_values[1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);
st_u32_values[2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);
st_u32_values[3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
st_u32_values[4] = (encoded[14] >> 2) | (encoded[15] << 7) | (local_signs << 16);
}
// 保存160bit的数据到st_buffer
st_buffer[0] = st_u32_values[0];
*(reinterpret_cast<int4*>(st_buffer + 1)) = *(reinterpret_cast<int4*>(st_u32_values + 1));
} else {
// 准备收发数据
using vec_type = int4;
const auto& ld_buffer_vec = reinterpret_cast<const vec_type*>(ld_buffer);
auto st_buffer_vec = reinterpret_cast<vec_type*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * kSendValueBytes);
constexpr int kLoopIter = kSendValueBytes / sizeof(vec_type);
#pragma unroll
for (int k = 0; k < kLoopIter; ++k) {
st_buffer_vec[k] = ld_nc_global(ld_buffer_vec + k);
st_buffer[0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);
st_buffer[1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);
st_buffer[2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);
st_buffer[3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
st_buffer[4] = (encoded[14] >> 2) | (encoded[15] << 7) | (local_signs << 16);
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp();
// 计算量化成功和失败时的数据量
constexpr int unable_cast_num_bytes = kWarpSize * kSendValueBytes; // = 64*2*16 = 2048
constexpr int enable_cast_num_bytes = unable_cast_num_bytes * 10 / 16; // = 2048/16*10=1280
......@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
for (int i = 0; i < kNumQuantGroupsPerWarp; ++i) { // sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto amax = static_cast<float>(bf162_amaxmin[i].x);
auto amin = static_cast<float>(bf162_amaxmin[i].y);
log_amax[i] = __builtin_log2f(amax);
log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : __builtin_fmaxf(__builtin_log2f(amin), log_amax[i] - kMinClip);
log_amax[i] = __builtin_amdgcn_logf(amax);
log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : __builtin_fmaxf(__builtin_amdgcn_logf(amin), log_amax[i] - kMinClip);
enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
}
......@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
auto decode = [=](const uint32_t& encoded, const uint32_t& sign) {
const auto decoded = encoded == 0 ? .0f : __builtin_exp2f((encoded - 1) * step + log_amin);
const auto decoded = encoded == 0 ? .0f : __builtin_amdgcn_exp2f((encoded - 1) * step + log_amin);
return sign ? -decoded : decoded;
};
......
......@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
return *reinterpret_cast<dtype_t *>(&ret);
}
// Loads one dtype_t value from `ptr` using a plain pointer dereference.
// The read goes through VecInt<sizeof(dtype_t)>::vec_t, and the loaded bits are
// then reinterpreted as dtype_t for the return value.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_direct_global(const dtype_t *ptr) {
    using vec_t = typename VecInt<sizeof(dtype_t)>::vec_t;
    const vec_t *src = reinterpret_cast<const vec_t *>(ptr);
    vec_t raw = *src;
    return *reinterpret_cast<dtype_t *>(&raw);
}
////////////////// used in ibgda
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
uint8_t *non_const_ptr = const_cast<uint8_t *>(ptr);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment