Commit 33bafa16 authored by lishen's avatar lishen
Browse files

lowlatency combine实现3级流水

parent 61bc0aff
......@@ -636,6 +636,7 @@ combine(void* combined_x,
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto num_warps = num_threads / kWarpSize;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
// Data type staffs
......@@ -656,7 +657,7 @@ combine(void* combined_x,
constexpr int kNumDivisions = kHidden / QUANTIZATION_GROUPSIZE;
constexpr int kNumMetaBytes = kNumDivisions * sizeof(__hip_bfloat162); // 用于记录数据的最大最小值
constexpr int kNumSendLogFMTBytes = kNumMsgInt4ElemPerWarp * sizeof(int4);
constexpr int kNumStages = 1; // 使用kNumStages>1,则需要的LDS大于64KB
constexpr int kNumStages = 3; // 使用kNumStages>1,则需要的LDS大于64KB
constexpr int kLogFMTShmemSize = kMaxNumWarps * (kNumStages * kNumSendLogFMTBytes + kNumMetaBytes);
__shared__ uint8_t smem_buffer[kLogFMTShmemSize];
/////////////////////////////////////////////
......@@ -707,6 +708,17 @@ combine(void* combined_x,
// Per-stage staging buffers: entry `i` is the int4 view of the region that starts
// `i * kNumSendLogFMTBytes` bytes past `smem_ptr` (one region per pipeline stage).
auto logfmt_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * kNumSendLogFMTBytes); });
// Stores the LogFMT max/min values. Placed immediately after the `kNumStages`
// staging regions; nullptr when `bSupportLogFMT` is false.
auto meta_buffers = bSupportLogFMT ? reinterpret_cast<__hip_bfloat162*>(smem_ptr + kNumStages * kNumSendLogFMTBytes) : nullptr;
// Number of bytes to stage for the chunk that begins at `offset_int4` (counted in int4
// elements): the full per-warp chunk size `kNumSendLogFMTBytes`, clamped to what is left
// before `hidden_bf16_int4`.
// NOTE(review): assumes `offset_int4 <= hidden_bf16_int4`; a larger offset would make the
// subtraction negative before it is multiplied by the unsigned `sizeof(int4)` — confirm callers.
auto get_num_logfmt_bytes = [&](const int& offset_int4) {
return min(kNumSendLogFMTBytes, static_cast<int>((hidden_bf16_int4 - offset_int4) * sizeof(int4)));
};
// Shorthand for the global -> LDS copy: moves `num_bytes / sizeof(int4)` int4 elements from
// `gmem_ptr` into the staging buffer of `stage_idx`, using `ld_direct_global` for the loads
// and `st_na_global` for the stores.
// `num_bytes` is truncated to whole int4 elements by the integer division.
// NOTE(review): the leading `1` is presumably the unroll factor of UNROLLED_WARP_COPY_LL and
// the copy is presumably distributed across lanes via `lane_id` — confirm against the macro.
auto logfmt_load_global2lds = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) {
UNROLLED_WARP_COPY_LL(1, lane_id, num_bytes / sizeof(int4),
reinterpret_cast<int4 *>(logfmt_buffers[stage_idx]),
reinterpret_cast<const int4 *>(gmem_ptr),
ld_direct_global, st_na_global);
};
// Unpack layout
int offset, num_tokens_to_send;
......@@ -728,66 +740,103 @@ combine(void* combined_x,
int num_send_bytes = hidden * sizeof(hip_bfloat16);
if (not zero_copy or dst_p2p_ptr != 0) {
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr): reinterpret_cast<int4*>(dst_p2p_ptr);
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr) : reinterpret_cast<int4*>(dst_p2p_ptr);
// 设置数据的真实偏移量
constexpr int kNumIters = hidden_bf16_int4 / kNumMsgInt4ElemPerWarp;
EP_STATIC_ASSERT(kNumIters >= 1, "hidden length too small");
if constexpr (bSupportLogFMT) {
// ===== LogFMT 路径:使用 LDS + encode + 多级流水 =====
int logfmt_offset_bytes = kNumMetaBytes;
// 进入循环,逐步拷贝数据
constexpr int encode_num_warps = hidden_bf16_int4 / kNumMsgInt4ElemPerWarp;
for (int iter_idx = 0; iter_idx < encode_num_warps; ++iter_idx) {
int num_logfmt_bytes = kNumMsgInt4ElemPerWarp * sizeof(int4);
// meta_buffers 存储的thread间隔
constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4;
// 记录S1~S3的编码字节数
int encoded_bytes[kNumStages];
// 原始数据的warp级编译
int warp_offset = iter_idx * kNumMsgInt4ElemPerWarp;
// Prefetch: iter0执行S1
logfmt_load_global2lds(0, cpy_src_int4_ptr, get_num_logfmt_bytes(0));
syncwarp();
if constexpr(bSupportLogFMT) {
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const int& stage_idx = iter_idx % kNumStages;
// Prefetch: iter0执行S2, iter1执行S1
if (kNumStages > 2 && kNumIters > 1) {
int warp_offset = /*1 * */kNumMsgInt4ElemPerWarp;
logfmt_load_global2lds(1, cpy_src_int4_ptr + warp_offset, get_num_logfmt_bytes(warp_offset));
// thread偏移
int thread_offset = /*0 + */lane_id * kNumSendUnrolls;
int num_bytes = logfmt_encode<kNumSendUnrolls>(
logfmt_buffers[0],
(thread_offset % kNumInt4PerDivision == 0) ? meta_buffers + thread_offset / kNumInt4PerDivision : nullptr,
lane_id
);
encoded_bytes[0] = num_bytes;
}
syncwarp();
// 采用3级流水
for (int iter_idx = 0; iter_idx < kNumIters; ++iter_idx) {
// 流水线S1: 加载第 (kNumStages-1) 轮之后的数据
const int stage_last_iter = iter_idx + kNumStages - 1; // 当前iter所在stage中的最后一个,初始为S3的读取数据
if (stage_last_iter < kNumIters) {
int stage_idx = stage_last_iter % kNumStages;
int warp_offset = stage_last_iter * kNumMsgInt4ElemPerWarp;
logfmt_load_global2lds(stage_idx, cpy_src_int4_ptr + warp_offset, get_num_logfmt_bytes(warp_offset));
}
// 流水线S2: 处理下一轮的数据量化
const int stage_next_iter = iter_idx + 1;
if (stage_next_iter < kNumIters) {
int stage_idx = stage_next_iter % kNumStages;
int warp_offset = stage_next_iter * kNumMsgInt4ElemPerWarp;
int thread_offset = warp_offset + lane_id * kNumSendUnrolls;
constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4; // = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes = logfmt_encode<kNumSendUnrolls>(
cpy_src_int4_ptr + warp_offset, // 等同于 x_int4
int num_bytes = logfmt_encode<kNumSendUnrolls>(
logfmt_buffers[stage_idx],
// NOTES: only the leader lane will write the result
(thread_offset % kNumInt4PerDivision == 0) ? meta_buffers + thread_offset / kNumInt4PerDivision : nullptr,
lane_id);
// 将量化后的数据写入
using vec_type = uint32_t;
UNROLLED_WARP_COPY_LL(2, lane_id, num_logfmt_bytes / sizeof(vec_type),
reinterpret_cast<vec_type *>(reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + logfmt_offset_bytes),
reinterpret_cast<vec_type *>(logfmt_buffers[stage_idx]),
ld_nc_global, st_na_global);
lane_id
);
encoded_bytes[stage_idx] = num_bytes;
}
// 起始地址偏移
logfmt_offset_bytes += num_logfmt_bytes;
} else {
// 非量化数据的传输
UNROLLED_WARP_COPY_LL(2, lane_id, kNumMsgInt4ElemPerWarp,
reinterpret_cast<int4*>(cpy_dst_int4_ptr + warp_offset),
reinterpret_cast<const int4*>(cpy_src_int4_ptr + warp_offset),
ld_nc_global, st_na_global);
// 流水线S3:当前轮进行数据拷贝到通信显存
if (iter_idx < kNumIters) {
int stage_idx = iter_idx % kNumStages;
using vec_type = uint64_t;
int nvecs = encoded_bytes[stage_idx] / sizeof(vec_type);
if (nvecs > 0) {
UNROLLED_WARP_COPY_LL(1, lane_id, nvecs,
reinterpret_cast<vec_type*>(reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + logfmt_offset_bytes),
reinterpret_cast<vec_type*>(logfmt_buffers[stage_idx]),
ld_direct_global, st_na_global);
}
logfmt_offset_bytes += encoded_bytes[stage_idx];
}
syncwarp();
}
// Store metadata (min/max values) for LogFMT
if constexpr (bSupportLogFMT) {
// 最终设置节点间传输的字节数
num_send_bytes = logfmt_offset_bytes;
using vec_type = uint32_t;
auto meta_buffers_ptr = reinterpret_cast<vec_type*>(meta_buffers);
auto cpy_dst_uint32_ptr = reinterpret_cast<vec_type*>(cpy_dst_int4_ptr);
// Store metadata
using meta_vec_type = uint32_t;
UNROLLED_WARP_COPY_LL(1, lane_id, kNumMetaBytes / sizeof(meta_vec_type),
reinterpret_cast<meta_vec_type*>(cpy_dst_int4_ptr),
reinterpret_cast<meta_vec_type*>(meta_buffers),
ld_direct_global, st_na_global);
for(int j = lane_id; j < kNumMetaBytes / sizeof(vec_type); j+=kWarpSize) {
*(cpy_dst_uint32_ptr + j) = meta_buffers_ptr[j];
} else {
// ===== 非 LogFMT 路径:直接 global -> global,不经过 LDS =====
for (int iter_idx = 0; iter_idx < kNumIters; ++iter_idx) {
int warp_offset = iter_idx * kNumMsgInt4ElemPerWarp;
UNROLLED_WARP_COPY_LL(kNumSendUnrolls, lane_id, kNumMsgInt4ElemPerWarp,
cpy_dst_int4_ptr + warp_offset,
cpy_src_int4_ptr + warp_offset,
ld_direct_global, st_na_global);
syncwarp();
}
// 非 LogFMT 时,发送字节数为原始大小
num_send_bytes = hidden_bf16_int4 * sizeof(int4); // 或根据实际计算
}
syncwarp();
}
......@@ -858,10 +907,6 @@ LOW_LATENCY_COMBINE_RECV:
// 计算需要多少个warp
constexpr int num_decode_warps = hidden_bf16_int4 / (kNumRecvUnrolls * kWarpSize);
// 限制thread_id
if (warp_id >= num_decode_warps) {
return;
}
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
constexpr int kNumDivisionBytes = kNumDivisions * sizeof(float);
......@@ -889,6 +934,7 @@ LOW_LATENCY_COMBINE_RECV:
topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
}
for (int w_i = warp_id; w_i < num_decode_warps; w_i += num_warps) {
float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
......@@ -906,7 +952,7 @@ LOW_LATENCY_COMBINE_RECV:
const uint8_t* data_buffer = rdma_buffer_type + kNumMetaBytes;
// 读取max/min数据
if(warp_id == 0) {
if(w_i == 0) {
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin<kNumDivisions / (kWarpSize / 16), kNumSendUnrolls, kNumRecvUnrolls>(
......@@ -920,13 +966,13 @@ LOW_LATENCY_COMBINE_RECV:
__syncthreads();
// 获取cast_info_buffers
const auto& info = cast_info_buffers[stage_idx][warp_id];
const auto& info = cast_info_buffers[stage_idx][w_i];
bool enable_cast = info & 1;
int num_casted_prefix = info >> 1; // 可用的
// 计算偏移(与TMA版本逻辑一致)
int warp_offset = kNumLogFMTPerWarpBytes * num_casted_prefix +
kNumBF16PerWarpBytes * (warp_id - num_casted_prefix);
kNumBF16PerWarpBytes * (w_i - num_casted_prefix);
int lane_offset = (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / kWarpSize * lane_id;
// 使用临时缓冲区进行归约
......@@ -940,7 +986,7 @@ LOW_LATENCY_COMBINE_RECV:
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int log_amaxmin_per_warp = kNumRecvUnrolls * kWarpSize * sizeof(int4) / sizeof(hip_bfloat16) / QUANTIZATION_GROUPSIZE;
int division_idx = warp_id * log_amaxmin_per_warp + lane_id * log_amaxmin_per_warp / kWarpSize;
int division_idx = w_i * log_amaxmin_per_warp + lane_id * log_amaxmin_per_warp / kWarpSize;
// 反量化
decode_and_accumulate<kNumRecvUnrolls>(
......@@ -955,7 +1001,7 @@ LOW_LATENCY_COMBINE_RECV:
const uint8_t* data_buffer = rdma_buffer_type;
// 计算偏移
int warp_offset = kNumBF16PerWarpBytes * warp_id;
int warp_offset = kNumBF16PerWarpBytes * w_i;
int lane_offset = kNumBF16PerWarpBytes / kWarpSize * lane_id;
// 使用临时缓冲区进行归约
const uint8_t* thread_data_ptr = data_buffer + warp_offset + lane_offset;
......@@ -984,7 +1030,8 @@ LOW_LATENCY_COMBINE_RECV:
for(int j = 0; j < kNumRecvUnrolls; ++ j) {
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 +
warp_id * kWarpSize * kNumRecvUnrolls)[lane_id * kNumRecvUnrolls + j] = combined_int4[j];
w_i * kWarpSize * kNumRecvUnrolls)[lane_id * kNumRecvUnrolls + j] = combined_int4[j];
}
}
}
}
......@@ -1001,7 +1048,7 @@ void combine(void* combined_x,
bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
constexpr int kMaxNumWarps = 8;
constexpr int kNumMaxTopk = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = kMaxNumWarps / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
......
......@@ -17,7 +17,7 @@ namespace internode_ll {
template <int kNumSendUnrolls>
__forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4* dst_buffer, __hip_bfloat162* shared_amaxmin, const int& lane_id) {
__forceinline__ __device__ int logfmt_encode(int4* lds_buffer, __hip_bfloat162* shared_amaxmin, const int& lane_id) {
EP_STATIC_ASSERT(kNumSendUnrolls == 2, "kNumSendUnrolls == 2 only");
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(__hip_bfloat16); // 8
......@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
const auto& bf162_values = reinterpret_cast<__hip_bfloat162*>(int4_values);
// Calculate lane offset
const auto& ld_buffer = cpy_src_int4_ptr + lane_id * kNumSendUnrolls;
const auto& ld_buffer = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(lds_buffer) + lane_id * kSendValueBytes);
const auto& st_buffer = reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(lds_buffer) + lane_id * kSendValueBytes * 10 / 16);
// Local log amax
auto bf162_amax = __hip_bfloat162(HIPRT_ZERO_BF16, HIPRT_ZERO_BF16);
......@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
const auto& fp162_max = __bfloat1622float2(bf162_amax);
auto amax = __builtin_fmaxf(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
auto amin = __builtin_fminf(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
......@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
if (shared_amaxmin != nullptr) {
*shared_amaxmin = __hip_bfloat162(amax, amin);
}
syncwarp();
// Calculate log amin/amax float
const auto& log_amax = __builtin_log2f(amax);
const auto& log_amin = __builtin_fmaxf(__builtin_log2f(amin), log_amax - kMinClip);
const auto& log_amax = __builtin_amdgcn_logf(amax);
const auto& log_amin = __builtin_fmaxf(__builtin_amdgcn_logf(amin), log_amax - kMinClip);
// 在组内广播enable_cast结果
const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);
// Case into LogFMT-10 if satisfied
if (enable_cast) {
constexpr int dst_buffer_step = kSendValueBytes * 10 / 16;
const auto& st_buffer = reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * dst_buffer_step);
uint32_t st_u32_values[dst_buffer_step / sizeof(uint32_t)]; // = 5
// 计算10bit数据的两个相邻数值的差值
const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
const auto step_inv = 1.0f / step;
// 计算舍入值
const auto rounding = 2.0f - __builtin_log2f((1.0f + __builtin_exp2f(step)) * 0.5f) * step_inv;
const auto rounding = 2.0f - __builtin_amdgcn_logf((1.0f + __builtin_amdgcn_exp2f(step)) * 0.5f) * step_inv;
const auto fused_rounding = rounding - log_amin * step_inv;
// 用于存储编码后的值
......@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
#pragma unroll
for (int k = 0; k < kNumElemsPerInt4; ++k) { // 8
// 将 bfloat162 转换为 float2
const auto& fp162_fvalue = __bfloat1622float2(bf162_values[k]);
const auto& fp322_fvalue = __bfloat1622float2(bf162_values[k]);
/*
实际进行压缩的公式为:
......@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
K: 压缩后的整数的最大值(即,K为2的幂)
*/
// 对 float 值进行编码
encoded[k * 2 + 0] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.x) * step_inv + fused_rounding, 0));
encoded[k * 2 + 1] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.y) * step_inv + fused_rounding, 0));
encoded[k * 2 + 0] = __float2uint_rd(__builtin_fmaxf(__builtin_amdgcn_logf(fp322_fvalue.x) * step_inv + fused_rounding, 0));
encoded[k * 2 + 1] = __float2uint_rd(__builtin_fmaxf(__builtin_amdgcn_logf(fp322_fvalue.y) * step_inv + fused_rounding, 0));
}
// 批量打包编码后的值到 st_buffer
st_u32_values[0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);
st_u32_values[1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);
st_u32_values[2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);
st_u32_values[3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
st_u32_values[4] = (encoded[14] >> 2) | (encoded[15] << 7) | (local_signs << 16);
}
// 保存160bit的数据到st_buffer
st_buffer[0] = st_u32_values[0];
*(reinterpret_cast<int4*>(st_buffer + 1)) = *(reinterpret_cast<int4*>(st_u32_values + 1));
} else {
// 准备收发数据
using vec_type = int4;
const auto& ld_buffer_vec = reinterpret_cast<const vec_type*>(ld_buffer);
auto st_buffer_vec = reinterpret_cast<vec_type*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * kSendValueBytes);
constexpr int kLoopIter = kSendValueBytes / sizeof(vec_type);
#pragma unroll
for (int k = 0; k < kLoopIter; ++k) {
st_buffer_vec[k] = ld_nc_global(ld_buffer_vec + k);
st_buffer[0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);
st_buffer[1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);
st_buffer[2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);
st_buffer[3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
st_buffer[4] = (encoded[14] >> 2) | (encoded[15] << 7) | (local_signs << 16);
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp();
// 计算量化成功和失败时的数据量
constexpr int unable_cast_num_bytes = kWarpSize * kSendValueBytes; // = 64*2*16 = 2048
constexpr int enable_cast_num_bytes = unable_cast_num_bytes * 10 / 16; // = 2048/16*10=1280
......@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
for (int i = 0; i < kNumQuantGroupsPerWarp; ++i) { // sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto amax = static_cast<float>(bf162_amaxmin[i].x);
auto amin = static_cast<float>(bf162_amaxmin[i].y);
log_amax[i] = __builtin_log2f(amax);
log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : __builtin_fmaxf(__builtin_log2f(amin), log_amax[i] - kMinClip);
log_amax[i] = __builtin_amdgcn_logf(amax);
log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : __builtin_fmaxf(__builtin_amdgcn_logf(amin), log_amax[i] - kMinClip);
enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
}
......@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
const auto& step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
auto decode = [=](const uint32_t& encoded, const uint32_t& sign) {
const auto decoded = encoded == 0 ? .0f : __builtin_exp2f((encoded - 1) * step + log_amin);
const auto decoded = encoded == 0 ? .0f : __builtin_amdgcn_exp2f((encoded - 1) * step + log_amin);
return sign ? -decoded : decoded;
};
......
......@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
return *reinterpret_cast<dtype_t *>(&ret);
}
// Reads one `dtype_t` from `ptr` with an ordinary pointer dereference (no load intrinsic).
// The memory is accessed through `VecInt<sizeof(dtype_t)>::vec_t` (declared elsewhere;
// presumably a vector type of the same byte width), and the loaded bits are then
// reinterpreted back as `dtype_t` and returned by value.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_direct_global(const dtype_t *ptr) {
    using raw_vec_t = typename VecInt<sizeof(dtype_t)>::vec_t;
    const auto *src = reinterpret_cast<const raw_vec_t *>(ptr);
    raw_vec_t loaded = *src;
    return *reinterpret_cast<dtype_t *>(&loaded);
}
////////////////// used in ibgda
__device__ __forceinline__ void st_na_relaxed(const uint8_t *ptr, uint8_t val) {
uint8_t *non_const_ptr = const_cast<uint8_t *>(ptr);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment