Commit e1d82bf3 authored by Yu Cheng, committed by GitHub


[Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16 (#141)

- Remove redundant `acc_s_0` fragment in flash attention kernel
- Simplify memory copy and reduction operations
- Reorder memory copy and scaling steps for improved performance
- Add Hopper-specific synchronization method in CUDA reduce template
- Update reduce operation to use architecture-specific synchronization
parent 3d7b2dc5
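
The first bullet is the core of the fix: before this commit the kernel copied the fp32 score fragment through the shared buffer S_shared (presumably float16, matching dtype) and back before the softmax, so every logit was rounded to half precision; after the change acc_s stays in fp32 through reduce_max and the exp2, and only the resulting probabilities are cast into acc_s_cast for the second GEMM. A minimal sketch of the rounding error such a round trip introduces, compiled with nvcc (illustrative only, not part of the commit; the logit value is made up):

// Illustrative only: magnitude of a float -> float16 -> float round trip,
// the kind of rounding the old acc_s -> S_shared -> acc_s copies applied to
// every attention logit before softmax. Build with nvcc.
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>

int main() {
  float logit = 10.3456789f;  // hypothetical fp32 attention score
  float rounded = __half2float(__float2half(logit));  // what an fp16 buffer would store
  std::printf("fp32 %.7f  after fp16 round trip %.7f  abs err %.2e\n",
              logit, rounded, std::fabs(logit - rounded));
  return 0;
}
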
@@ -31,7 +31,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
-            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
@@ -57,28 +56,27 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             for k in T.Pipelined(loop_range, num_stages=2):
                 T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
-                T.clear(acc_s_0)
+                T.clear(acc_s)
                 T.gemm(
-                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.gemm(
                     Q_pe_shared,
                     K_pe_shared,
-                    acc_s_0,
+                    acc_s,
                     transpose_B=True,
                     policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
-                T.copy(acc_s_0, S_shared)
-                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
                     acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                 T.reduce_sum(acc_s, scores_sum, dim=1)
-                T.copy(acc_s, S_shared)
-                T.copy(S_shared, acc_s_cast)
                 for i in T.Parallel(block_H):
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
                 T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
@@ -105,7 +103,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
-            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
@@ -131,31 +128,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             for k in T.Pipelined(loop_range, num_stages=2):
                 kv_start = (seqlen_kv // num_split) * bz + k * block_N
                 kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
                 T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
-                T.clear(acc_s_0)
+                T.clear(acc_s)
                 T.gemm(
-                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.gemm(
                     Q_pe_shared,
                     K_pe_shared,
-                    acc_s_0,
+                    acc_s,
                     transpose_B=True,
                     policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
-                T.copy(acc_s_0, S_shared)
-                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
                     acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                 T.reduce_sum(acc_s, scores_sum, dim=1)
-                T.copy(acc_s, S_shared)
-                T.copy(S_shared, acc_s_cast)
                 for i in T.Parallel(block_H):
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
                 T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
@@ -301,4 +296,4 @@ if __name__ == "__main__":
     print("All close")
     latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file
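
Both kernel variants above implement the standard online-softmax (flash-attention) update: keep a running row maximum and a running denominator, and rescale the partial output acc_o whenever the maximum grows. A rough host-side sketch of that recurrence, one score at a time rather than block_N scores per step (names like row_max, row_sum and the logit values are illustrative; scale standing in for a combined 1/sqrt(d) * log2(e) factor is an assumption about how scale is defined outside this diff):

// Illustrative host-side sketch of the online-softmax recurrence used above.
// The kernel folds a whole block_N of scores into each step; here we do one
// score at a time to keep the rescaling logic visible.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> logits = {1.0f, 3.5f, 2.0f, 7.25f};  // made-up scores for one row
  float scale = 1.0f;            // stand-in for the kernel's softmax scale
  float row_max = -INFINITY;     // scores_max
  float row_sum = 0.0f;          // logsum
  for (float s : logits) {
    float new_max = std::max(row_max, s);
    float rescale = std::exp2f((row_max - new_max) * scale);          // scores_scale
    row_sum = row_sum * rescale + std::exp2f((s - new_max) * scale);  // logsum update
    // acc_o would be multiplied by `rescale` here before adding its new term
    row_max = new_max;
  }
  std::printf("softmax denominator for the row: %f\n", row_sum);
  return 0;
}
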
@@ -161,8 +161,13 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
        continue;
      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;
-     ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
-        << reducing_threads << ", " << (*scale) << ">::run";
+     if (Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
+       ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
+          << reducing_threads << ", " << (*scale) << ">::run_hopper";
+     } else {
+       ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
+          << reducing_threads << ", " << (*scale) << ">::run";
+     }
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
      if (reducing_threads >= 32) {
@@ -33,10 +33,8 @@ template <class Reducer, int threads, int scale> struct AllReduce {
     constexpr int offset = threads / 2;
     if constexpr (offset >= 32) {
       __syncthreads();
-      // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(256));
       red_buf[threadIdx.x] = x;
       __syncthreads();
-      // asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(256));
       x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
     } else {
       x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
@@ -47,6 +45,24 @@ template <class Reducer, int threads, int scale> struct AllReduce {
       return AllReduce<Reducer, offset, scale>::run(x, red_buf);
     }
   }
+
+  template <typename T>
+  static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
+    constexpr int offset = threads / 2;
+    if constexpr (offset >= 32) {
+      asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(threads));
+      red_buf[threadIdx.x] = x;
+      asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(threads));
+      x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
+    } else {
+      x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
+    }
+    if constexpr (offset == scale) {
+      return x;
+    } else {
+      return AllReduce<Reducer, offset, scale>::run_hopper(x, red_buf);
+    }
+  }
 };

 } // namespace tl
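
A note on the two synchronization styles in this header (my reading, not stated in the commit): __syncthreads() is PTX barrier 0 and requires every thread in the block to arrive, whereas bar.sync with a nonzero barrier id and an explicit count only waits for that many threads, which presumably lets warp-specialized sm_90 kernels keep producer warps out of the reduction entirely. A hypothetical helper showing the named-barrier form that run_hopper inlines:

// Sketch of a named barrier (the helper name is hypothetical).
// Synchronizes exactly `count` threads at barrier `id`; `count` must be a
// multiple of the warp size (32), and barrier 0 is what __syncthreads() uses.
__device__ __forceinline__ void named_barrier(int id, int count) {
  asm volatile("bar.sync %0, %1;" : : "r"(id), "r"(count));
}
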