Commit c73d2c3c authored by zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.1-dev_whl' into v0.9.1-dev

parents e522713c ac61d69b
...@@ -269,9 +269,9 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size] ...@@ -269,9 +269,9 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
[&] { [&] {
using T_ACC = at::acc_type<scalar_t, true>; using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon; T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>(); scalar_t* self_data = input.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* out_data =out.data_ptr<scalar_t>(); scalar_t* out_data = out.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>(); scalar_t* weight_data= weight.expect_contiguous()->data_ptr<scalar_t>();
if (hidden_size<=1024){ if (hidden_size<=1024){
fused_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps); fused_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
} }
...@@ -330,7 +330,7 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size] ...@@ -330,7 +330,7 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){ if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2( AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Half,
at::ScalarType::BFloat16, at::ScalarType::BFloat16,
...@@ -339,9 +339,9 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size] ...@@ -339,9 +339,9 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
[&] { [&] {
using T_ACC = at::acc_type<scalar_t, true>; using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon; T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>(); scalar_t* self_data = input.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* other_data =residual.data_ptr<scalar_t>(); scalar_t* other_data = residual.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>(); scalar_t* weight_data= weight.expect_contiguous()->data_ptr<scalar_t>();
if (hidden_size<=1024){ if (hidden_size<=1024){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps); fused_add_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment