Unverified commit dd949ace authored by Yineng Zhang, committed by GitHub

Revert "[1/2][resubmit] sgl-kernel: Fuse routed scaling factor into m… (#9035)

parent f2887498
@@ -132,7 +132,6 @@ class TopK(CustomOp):
 scoring_func: str = "softmax",
 correction_bias: Optional[torch.Tensor] = None,
 routed_scaling_factor: Optional[float] = None,
-apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
 # NOTE: scoring_func is not used for now, but we keep it for future use
 # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -148,9 +147,6 @@ class TopK(CustomOp):
 self.custom_routing_function = custom_routing_function
 self.correction_bias = correction_bias
 self.routed_scaling_factor = routed_scaling_factor
-self.apply_routed_scaling_factor_on_output = (
-apply_routed_scaling_factor_on_output
-)
 self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
@@ -211,7 +207,6 @@ class TopK(CustomOp):
 routed_scaling_factor=self.routed_scaling_factor,
 num_token_non_padded=num_token_non_padded,
 expert_location_dispatch_info=expert_location_dispatch_info,
-apply_routed_scaling_factor_on_output=self.apply_routed_scaling_factor_on_output,
 )
 def forward_cpu(
@@ -381,7 +376,6 @@ def grouped_topk_gpu(
 routed_scaling_factor: Optional[float] = None,
 num_token_non_padded: Optional[torch.Tensor] = None,
 expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
 assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -429,8 +423,6 @@ def grouped_topk_gpu(
 else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
 )
 topk_weights = topk_weights / topk_weights_sum
-if apply_routed_scaling_factor_on_output:
-topk_weights *= routed_scaling_factor
 topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -479,7 +471,6 @@ def biased_grouped_topk_impl(
 routed_scaling_factor: Optional[float] = None,
 num_token_non_padded: Optional[torch.Tensor] = None,
 expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
 assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -531,8 +522,6 @@ def biased_grouped_topk_impl(
 else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
 )
 topk_weights = topk_weights / topk_weights_sum
-if apply_routed_scaling_factor_on_output:
-topk_weights *= routed_scaling_factor
 topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -575,10 +564,7 @@ def biased_grouped_topk_gpu(
 routed_scaling_factor: Optional[float] = None,
 num_token_non_padded: Optional[torch.Tensor] = None,
 expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
-# TODO(trevor-m): Remove once sgl-kernel is updated
-assert not apply_routed_scaling_factor_on_output
 assert (
 routed_scaling_factor is not None
 ), "routed_scaling_factor is required for biased_grouped_topk"
@@ -597,8 +583,6 @@ def biased_grouped_topk_gpu(
 topk,
 num_fused_shared_experts,
 routed_scaling_factor,
-# TODO(trevor-m): Uncomment once sgl-kernel is updated
-# apply_routed_scaling_factor_on_output,
 )
 # TODO merge into kernel
 if (expert_location_dispatch_info is not None) or (
@@ -609,7 +593,6 @@ def biased_grouped_topk_gpu(
 )
 return topk_weights, topk_ids
 elif _use_aiter:
-assert not apply_routed_scaling_factor_on_output, "Not implemented"
 token = gating_output.shape[0]
 device = gating_output.device
 assert (
@@ -641,7 +624,6 @@ def biased_grouped_topk_gpu(
 routed_scaling_factor=routed_scaling_factor,
 num_token_non_padded=num_token_non_padded,
 expert_location_dispatch_info=expert_location_dispatch_info,
-apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
 )
@@ -701,7 +683,6 @@ def select_experts(
 routed_scaling_factor: Optional[float] = None,
 num_token_non_padded: Optional[torch.Tensor] = None,
 expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ) -> TopKOutput:
 router_logits, correction_bias = (
 expert_location_dispatch.transform_select_experts_inputs(
@@ -727,7 +708,6 @@ def select_experts(
 routed_scaling_factor=routed_scaling_factor,
 num_token_non_padded=num_token_non_padded,
 expert_location_dispatch_info=expert_location_dispatch_info,
-apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
 )
 else:
 topk_weights, topk_ids = biased_grouped_topk(
@@ -742,14 +722,12 @@ def select_experts(
 routed_scaling_factor=routed_scaling_factor,
 num_token_non_padded=num_token_non_padded,
 expert_location_dispatch_info=expert_location_dispatch_info,
-apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
 )
 elif torch_native and custom_routing_function is None:
 assert (
 num_token_non_padded is None
 ), "num_token_non_padded is not yet supported in fused_topk_native"
 assert expert_location_dispatch_info is None
-assert not apply_routed_scaling_factor_on_output, "Not implemented"
 topk_weights, topk_ids = fused_topk_native(
 hidden_states=hidden_states,
 gating_output=router_logits,
@@ -757,7 +735,6 @@ def select_experts(
 renormalize=renormalize,
 )
 elif custom_routing_function is None:
-assert not apply_routed_scaling_factor_on_output, "Not implemented"
 # Qwen3MOE uses fused_topk
 topk_weights, topk_ids = fused_topk(
 hidden_states=hidden_states,
@@ -772,7 +749,6 @@ def select_experts(
 num_token_non_padded is None
 ), "num_token_non_padded is not yet supported in custom_routing_function"
 assert expert_location_dispatch_info is None
-assert not apply_routed_scaling_factor_on_output, "Not implemented"
 topk_weights, topk_ids = custom_routing_function(
 hidden_states=hidden_states,
 gating_output=router_logits,
......
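For context, the Python-side behavior being reverted above is just an optional extra multiply on the renormalized top-k weights (see the removed `topk_weights *= routed_scaling_factor` lines in `grouped_topk_gpu` and `biased_grouped_topk_impl`). Below is a minimal PyTorch sketch of the two variants; the helper name `renormalize_topk_weights` is hypothetical and not an sglang API.

import torch
from typing import Optional

def renormalize_topk_weights(
    topk_weights: torch.Tensor,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: bool = False,
) -> torch.Tensor:
    # Renormalize so each token's selected weights sum to 1.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Pre-revert fused path: fold the routed scaling factor into the weights here.
    # After the revert, any scaling is left to the caller outside top-k selection.
    if apply_routed_scaling_factor_on_output and routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32)

# Example: 4 tokens, top-8 weights, scaling factor 2.5 as used in the test further below.
weights = torch.rand(4, 8)
print(renormalize_topk_weights(weights, routed_scaling_factor=2.5,
                               apply_routed_scaling_factor_on_output=True))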
@@ -174,7 +174,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 m.def(
 "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
-"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
+"num_fused_shared_experts, float routed_scaling_factor) -> "
 "(Tensor[])");
 m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
 m.def(
......
@@ -59,7 +59,6 @@ __device__ void moe_fused_gate_impl(
 int64_t topk,
 int64_t num_fused_shared_experts,
 double routed_scaling_factor,
-bool apply_routed_scaling_factor_on_output,
 Params params) {
 int tidx = threadIdx.x;
 int64_t thread_row =
@@ -249,9 +248,6 @@ __device__ void moe_fused_gate_impl(
 for (int ii = 0; ii < topk; ++ii) {
 int64_t const idx = topk * thread_row + ii;
 output_ptr[idx] = output_ptr[idx] / output_sum;
-if (apply_routed_scaling_factor_on_output) {
-output_ptr[idx] *= routed_scaling_factor;
-}
 }
 }
 }
@@ -286,8 +282,7 @@ __global__ void moe_fused_gate_kernel(
 int64_t topk_group,
 int64_t topk,
 int64_t num_fused_shared_experts,
-double routed_scaling_factor,
-bool apply_routed_scaling_factor_on_output) {
+double routed_scaling_factor) {
 KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
 moe_fused_gate_impl<T>(
 input,
@@ -299,7 +294,6 @@ __global__ void moe_fused_gate_kernel(
 topk,
 num_fused_shared_experts,
 routed_scaling_factor,
-apply_routed_scaling_factor_on_output,
 params);
 }
@@ -320,8 +314,7 @@ __global__ void moe_fused_gate_kernel(
 topk_group, \
 topk, \
 num_fused_shared_experts, \
-routed_scaling_factor, \
-apply_routed_scaling_factor_on_output); \
+routed_scaling_factor); \
 dispatched = true; \
 } while (0)
@@ -349,8 +342,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
 int64_t topk_group,
 int64_t topk,
 int64_t num_fused_shared_experts,
-double routed_scaling_factor,
-bool apply_routed_scaling_factor_on_output) {
+double routed_scaling_factor) {
 KernelParamsDynamic params;
 params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -369,7 +361,6 @@ __global__ void moe_fused_gate_kernel_dynamic(
 topk,
 num_fused_shared_experts,
 routed_scaling_factor,
-apply_routed_scaling_factor_on_output,
 params);
 }
@@ -383,8 +374,7 @@ std::vector<at::Tensor> moe_fused_gate(
 int64_t topk_group,
 int64_t topk,
 int64_t num_fused_shared_experts,
-double routed_scaling_factor,
-bool apply_routed_scaling_factor_on_output) {
+double routed_scaling_factor) {
 int64_t num_rows = input.size(0);
 int32_t num_experts = input.size(1);
 auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -483,8 +473,7 @@ std::vector<at::Tensor> moe_fused_gate(
 topk_group,
 topk,
 num_fused_shared_experts,
-routed_scaling_factor,
-apply_routed_scaling_factor_on_output);
+routed_scaling_factor);
 } else if (input.scalar_type() == at::kHalf) {
 moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
 input.data_ptr(),
@@ -497,8 +486,7 @@ std::vector<at::Tensor> moe_fused_gate(
 topk_group,
 topk,
 num_fused_shared_experts,
-routed_scaling_factor,
-apply_routed_scaling_factor_on_output);
+routed_scaling_factor);
 } else if (input.scalar_type() == at::kFloat) {
 moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
 input.data_ptr(),
@@ -511,8 +499,7 @@ std::vector<at::Tensor> moe_fused_gate(
 topk_group,
 topk,
 num_fused_shared_experts,
-routed_scaling_factor,
-apply_routed_scaling_factor_on_output);
+routed_scaling_factor);
 } else {
 TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
 }
......
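In terms of the math, the only functional change in `moe_fused_gate_impl` is in the final per-row renormalization loop shown above. With $o_i$ the gate value kept for expert $i$ of a row, $k$ the number of kept experts, and $s$ the `routed_scaling_factor`, the reverted kernel always writes

$$w_i = \frac{o_i}{\sum_{j=1}^{k} o_j},$$

whereas the removed branch optionally emitted $w_i' = s \cdot w_i$ so the scale did not have to be applied again downstream.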
@@ -243,8 +243,7 @@ std::vector<at::Tensor> moe_fused_gate(
 int64_t topk_group,
 int64_t topk,
 int64_t num_fused_shared_experts,
-double routed_scaling_factor,
-bool apply_routed_scaling_factor_on_output);
+double routed_scaling_factor);
 void fp8_blockwise_scaled_grouped_mm(
 torch::Tensor& output,
......
@@ -44,7 +44,6 @@ def moe_fused_gate(
 topk,
 num_fused_shared_experts=0,
 routed_scaling_factor=0,
-apply_routed_scaling_factor_on_output=False,
 ):
 # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
 # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
@@ -52,13 +51,8 @@ def moe_fused_gate(
 # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
 # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
 # for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
-# num_fused_shared_experts: if > 0, the last several experts will be
-# replaced with shared experts. the shared experts will be divided by the
-# routed_scaling_factor - this is intended to cancel out later when routed+shared
-# output is scaled so that shared experts are not scaled.
-# routed_scaling_factor: if > 0, the experts will be scaled by this factor
-# apply_routed_scaling_factor_on_output: if true, output will be
-# scaled by the routed_scaling_factor
+# num_fused_shared_experts: if > 0, the last several experts will be replaced with shared experts
+# routed_scaling_factor: if > 0, the shared experts will be scaled by this factor
 return torch.ops.sgl_kernel.moe_fused_gate.default(
 input_tensor,
 bias,
@@ -67,7 +61,6 @@ def moe_fused_gate(
 topk,
 num_fused_shared_experts,
 routed_scaling_factor,
-apply_routed_scaling_factor_on_output,
 )
......
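After this revert, the `sgl_kernel.moe_fused_gate` wrapper is back to seven arguments. The following is a minimal usage sketch, assuming a CUDA device, an sgl-kernel build with this reverted signature, and that `moe_fused_gate` is importable from the `sgl_kernel` package; the 256-expert / 8-group shape mirrors the DeepSeek-V3 example in the kernel comments, and all values are illustrative.

import torch
from sgl_kernel import moe_fused_gate

num_tokens, num_experts = 16, 256  # 256 experts split into 8 groups, as in the kernel comment
scores = torch.rand(num_tokens, num_experts, dtype=torch.float32, device="cuda")
bias = torch.rand(num_experts, dtype=torch.float32, device="cuda")

# Hierarchical 2-layer selection: rank the 8 expert groups, keep topk_group of them,
# then take the global top-8 experts from the surviving groups.
topk_weights, topk_ids = moe_fused_gate(
    scores,  # input_tensor: [num_tokens, num_experts] router logits
    bias,    # correction bias, one value per expert
    8,       # num_expert_group
    4,       # topk_group
    8,       # topk
    0,       # num_fused_shared_experts
    2.5,     # routed_scaling_factor (no fused scaling flag after this revert)
)
print(topk_weights.shape, topk_ids.shape)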
@@ -19,10 +19,7 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
 ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
-def test_moe_fused_gate_combined(
-seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
-):
+def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
 num_experts, num_expert_group, topk_group, topk = params
 dtype = torch.float32
@@ -40,7 +37,6 @@ def test_moe_fused_gate_combined(
 topk=topk,
 num_fused_shared_experts=num_fused_shared_experts,
 routed_scaling_factor=2.5,
-apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
 )
 ref_output, ref_indices = biased_grouped_topk(
 scores,
@@ -52,7 +48,6 @@ def test_moe_fused_gate_combined(
 topk_group=topk_group,
 num_fused_shared_experts=num_fused_shared_experts,
 routed_scaling_factor=2.5,
-apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
 )
 # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
......