Commit 58b43d4a authored by zhanghj2
Browse files

修改写出

parent d6379e50
......@@ -131,7 +131,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*256); // NOTE We don't use __ldg here since it is incompatible with PDL
datas[i] = *(float4*)(oaccum_ptr + lane_idx*8 + i*4); // NOTE We don't use __ldg here since it is incompatible with PDL
}
// Warp #i accumulates activation for seq #i
{
......@@ -155,7 +155,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
result[i].z += lse_scale * datas[i].z;
result[i].w += lse_scale * datas[i].w;
if (split != my_num_splits-1) {
datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*256);
datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*8 + i*4);
}
}
// }
......@@ -182,7 +182,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
data_converted[2] = (ElementT)(data.z);
data_converted[3] = (ElementT)(data.w);
static_assert(sizeof(ElementT) == 2);
*(uint64_t*)(o_ptr + lane_idx*4 + i*256) = *(uint64_t*)data_converted;
*(uint64_t*)(o_ptr + lane_idx*8 + i*4) = *(uint64_t*)data_converted;
}
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment