Commit b3a96764 authored by danyao12

Merge branch 'mha-train-develop' into mha-train-ldsbypass

parents ece8f9b8 1ab31830
@@ -113,8 +113,8 @@ __global__ void
 p_b_grid + b_batch_offset,
 p_b1_grid + b1_batch_offset,
 p_c_grid + c_batch_offset,
-nullptr ? nullptr : p_z_grid + z_batch_offset,
-nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
 p_shared,
 a_element_op,
 b_element_op,
@@ -142,8 +142,8 @@ __global__ void
 p_b_grid + b_batch_offset,
 p_b1_grid + b1_batch_offset,
 p_c_grid + c_batch_offset,
-nullptr ? nullptr : p_z_grid + z_batch_offset,
-nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
 p_shared,
 a_element_op,
 b_element_op,
@@ -117,8 +117,8 @@ __global__ void
 p_b_grid + b_batch_offset,
 p_b1_grid + b1_batch_offset,
 p_c_grid + c_batch_offset,
-nullptr ? nullptr : p_z_grid + z_batch_offset,
-nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
 p_shared,
 a_element_op,
 b_element_op,
@@ -148,8 +148,8 @@ __global__ void
 p_b_grid + b_batch_offset,
 p_b1_grid + b1_batch_offset,
 p_c_grid + c_batch_offset,
-nullptr ? nullptr : p_z_grid + z_batch_offset,
-nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
+p_lse_grid == nullptr ? nullptr : p_lse_grid + lse_batch_offset,
 p_shared,
 a_element_op,
 b_element_op,
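The change at the four call sites above fixes a guard that could never fire. In the old expression the ternary condition is the literal nullptr, which converts to false, so p_z_grid + z_batch_offset (and likewise the LSE offset) was evaluated unconditionally, even when the optional Z or LSE buffer was not provided; that is pointer arithmetic on a null pointer and hands the kernel a bogus non-null address. The new expression tests the pointer itself and only applies the per-batch offset when the buffer exists. A minimal host-side sketch of the corrected guard; the guarded_offset helper and the float element type are illustrative, not from the commit:

#include <cstdio>

// Illustrative stand-in for the per-batch offset logic at the call sites above.
static float* guarded_offset(float* p_z_grid, long z_batch_offset)
{
    // Apply the per-batch offset only when the optional buffer actually exists;
    // otherwise forward nullptr untouched so the kernel can skip that output.
    return p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset;
}

int main()
{
    float storage[256] = {};

    // Optional buffer absent: nullptr passes through unchanged.
    std::printf("%p\n", static_cast<void*>(guarded_offset(nullptr, 128)));

    // Optional buffer present: the batch offset is applied as before.
    std::printf("%p\n", static_cast<void*>(guarded_offset(storage, 128)));

    // Old (buggy) form for comparison; the condition is the literal nullptr,
    // which is always false, so the offset branch ran unconditionally:
    //     nullptr ? nullptr : p_z_grid + z_batch_offset
    return 0;
}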
@@ -1172,7 +1172,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
 static_for<0, MXdlPerWave, 1>{}(
     [&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
-if(get_warp_local_1d_id() < AccM2)
+if(get_lane_local_1d_id() < AccM2)
 {
     static_for<0, MXdlPerWave, 1>{}([&](auto I) {
         // copy from VGPR to Global
@@ -1350,7 +1350,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
 static_for<0, MXdlPerWave, 1>{}(
     [&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
-if(get_warp_local_1d_id() < AccM2)
+if(get_lane_local_1d_id() < AccM2)
 {
     static_for<0, MXdlPerWave, 1>{}([&](auto I) {
         // copy from VGPR to Global
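For the unchanged context line in both gridwise structs: lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)) has the shape of the numerically stable log-sum-exp identity, log(sum_j exp(s_j)) = m + log(sum_j exp(s_j - m)) with m the row maximum. A standalone sketch of that identity follows; it is plain C++ with illustrative names, and unlike this two-pass version the kernel maintains its running max and sum incrementally across tiles:

#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable log-sum-exp: track the maximum, accumulate exp(x - max),
// then recover log(sum of exp(x)) as max + log(accumulated sum).
static double log_sum_exp(const std::vector<double>& scores)
{
    double running_max = -INFINITY;
    for (double s : scores)
        running_max = std::fmax(running_max, s);

    double running_sum = 0.0;
    for (double s : scores)
        running_sum += std::exp(s - running_max);

    // Same shape as the kernel line: lse = running_max + log(running_sum).
    return running_max + std::log(running_sum);
}

int main()
{
    const std::vector<double> scores = {1000.0, 1001.0, 999.5}; // naive exp() would overflow here
    std::printf("lse = %.6f\n", log_sum_exp(scores));
    return 0;
}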
@@ -19,6 +19,8 @@ __device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x +
 
 __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }
 
+__device__ index_t get_lane_local_1d_id() { return threadIdx.x % get_warp_size(); }
+
 __device__ index_t get_block_1d_id() { return blockIdx.x; }
 
 __device__ index_t get_grid_size() { return gridDim.x; }
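The predicate change in both gridwise structs and the new helper above belong together. get_warp_local_1d_id() is threadIdx.x / get_warp_size(), i.e. which wave of the thread block a thread sits in, so the old guard get_warp_local_1d_id() < AccM2 selected whole waves; the newly added get_lane_local_1d_id() is threadIdx.x % get_warp_size(), the thread's position inside its wave, so the new guard selects the first AccM2 lanes of every wave before the VGPR-to-global copy shown in those hunks. A small host-side sketch of the quotient/remainder split; warp size 64, block size 256 and AccM2 = 4 are illustrative values, not taken from the kernel:

#include <cstdio>

int main()
{
    // Illustrative values; the real code queries get_warp_size() (64 for AMD wavefronts).
    const int warp_size  = 64;
    const int block_size = 256;
    const int AccM2      = 4;

    for (int tid = 0; tid < block_size; ++tid)
    {
        const int warp_local_1d_id = tid / warp_size; // which wave of the block (old guard)
        const int lane_local_1d_id = tid % warp_size; // position inside that wave (new guard)

        // Old guard: warp_local_1d_id < AccM2 holds for every thread of the first
        // AccM2 waves (here all 256 threads, since the block has exactly 4 waves).
        // New guard: lane_local_1d_id < AccM2 holds only for the first AccM2 lanes
        // of each wave (threads 0-3, 64-67, 128-131, 192-195).
        if (lane_local_1d_id < AccM2)
            std::printf("thread %3d -> wave %d, lane %d\n",
                        tid, warp_local_1d_id, lane_local_1d_id);
    }
    return 0;
}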