Commit c73d2c3c authored by zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.1-dev_whl' into v0.9.1-dev

parents e522713c ac61d69b
......@@ -269,9 +269,9 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* out_data =out.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
scalar_t* self_data = input.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* out_data = out.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* weight_data= weight.expect_contiguous()->data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
......@@ -330,7 +330,7 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
......@@ -339,9 +339,9 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* other_data =residual.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
scalar_t* self_data = input.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* other_data = residual.expect_contiguous()->data_ptr<scalar_t>();
scalar_t* weight_data= weight.expect_contiguous()->data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment