Unverified Commit 8e2d37e9 authored by Autumn1998's avatar Autumn1998 Committed by GitHub
Browse files

[PyTorch] Fix corner case in router fusion (#2009)



* fix bug if all values<0
Signed-off-by: tongliu <tongliu@nvidia.com>

* minor fix
Signed-off-by: tongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: tongliu <tongliu@nvidia.com>
Co-authored-by: tongliu <tongliu@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 8dfdb911
......@@ -148,11 +148,21 @@ def run_comparison(
# Set some parameters
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
)
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
......@@ -281,11 +291,21 @@ def test_topk_softmax(
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
)
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
......@@ -321,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
# Construct the special probs to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
probs = probs.view(num_tokens, num_experts)
probs.requires_grad = True
......@@ -379,15 +399,12 @@ def profile_topk_softmax(
if __name__ == "__main__":
test_fused_scores_for_aux_loss(
dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
test_topk_softmax(
dtype=torch.float32,
num_tokens=1024,
num_experts=128,
topk=4,
use_pre_softmax=False,
group_topk=None,
scaling_factor=None,
)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
......@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
__syncwarp();
if (lane_id == 0) {
......@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
__syncwarp();
if (lane_id == 0) {
......
......@@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
if (score_function == 0) {
if (topk > 1) {
auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
auto sum_logits =
warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
(static_cast<double>(sum_logits) + epsilon));
......@@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int
*/
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
auto sum_fwd_input =
warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
}
__syncwarp();
auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
auto sum_Output_x_Grad =
warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] =
......
......@@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
// score_function == 0 means sigmoid
if (score_function == 0) {
if (topk > 1) {
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
}
......@@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
/*data ptr = */ local_act_from_fwd,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
/*reduce func = */ ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
......@@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
/*data ptr = */ local_comp_buf,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
/*reduce func = */ ReduceFuncType::SUM, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) {
......
......@@ -26,14 +26,28 @@ __device__ inline T sum(T a, T b) {
return a + b;
}
enum ReduceFuncType {
SUM,
MAX,
};
template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type,
int lane_id) {
T (*reduce_func)(T, T);
double default_val = 0;
if (type == ReduceFuncType::SUM) {
reduce_func = sum;
default_val = 0;
} else if (type == ReduceFuncType::MAX) {
reduce_func = max;
default_val = -std::numeric_limits<double>::infinity();
}
  // Some value is handled in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread
volatile double val =
lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
volatile double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
val = reduce_func(val, data_ptr[i]);
}
......@@ -57,13 +71,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
T (*reduce_func)(T, T), int lane_id) {
ReduceFuncType type, int lane_id) {
T (*reduce_func)(T, T);
double default_val = 0;
if (type == ReduceFuncType::SUM) {
reduce_func = sum;
default_val = 0;
} else if (type == ReduceFuncType::MAX) {
reduce_func = max;
default_val = -std::numeric_limits<double>::infinity();
}
  // Some value is handled in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread
volatile double val = lane_id < data_size && mask[lane_id]
? static_cast<double>(data_ptr[lane_id])
: static_cast<double>(0);
volatile double val =
lane_id < data_size && mask[lane_id] ? static_cast<double>(data_ptr[lane_id]) : default_val;
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
if (mask[i]) {
val = reduce_func(val, data_ptr[i]);
......@@ -108,7 +131,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
float sum_Output_x_Grad = warp_reduce_on_shmem(
/*data ptr = */ comp_buf,
/*data size = */ data_size,
/*reduce func = */ sum, lane_id);
/*reduce func = */ ReduceFuncType::SUM, lane_id);
// In-place update
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) {
......@@ -127,14 +150,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
// 1. compute the max of value
float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
float max_val =
static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id));
// 2. value -> exp_value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
}
__syncwarp();
// 3. compute the sum of exp_value
float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
float sum_val =
static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id));
// 4. update the softmax value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(scores[i]) / sum_val;
......@@ -145,19 +170,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
T *topk_scores, int lane_id) {
// Check if the index is masked by the later iteration
auto is_masked = [&topk_indices](int k, int index) {
if (k == 0) return false;
for (int i = 0; i < k; i++) {
if (topk_indices[i] == index) return true;
}
return false;
};
// Topk Times: Find the max value and its index
// Then mask it, and record the index in the topk_indices
// After looping topk times, the topk_indices will be the topk indices
for (int k = 0; k < topk; k++) {
// Find the max value and its index
volatile double val =
(lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
volatile double val = (lane_id < data_size && !is_masked(k, lane_id))
? static_cast<double>(scores[lane_id])
: -std::numeric_limits<double>::infinity();
volatile int index = (lane_id < data_size) ? lane_id : 0;
    // Some value is handled in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
volatile double cur_val = scores[i];
volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits<double>::infinity()
: static_cast<double>(scores[i]);
if (cur_val > val) {
val = cur_val;
index = i;
......@@ -175,17 +210,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
if (lane_id == 0) {
topk_indices[k] = index;
topk_scores[k] = val;
scores[index] =
static_cast<double>(-1.0) - val; // make the selected experts using val = - 1 - val
}
__syncwarp();
}
// Reset the scores to the original value
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
scores[topk_indices[i]] =
static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
}
}
// Current TE only support float32/bf16/fp16, float64 probs should be considered in the future
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment