"docs/vscode:/vscode.git/clone" did not exist on "c6b636f9fbfd0308ee8d883afd8ccd7ef823eb25"
Commit 2dbefd03 authored by zhangshao's avatar zhangshao
Browse files

Update layernorm_kernels.cu

parent 785f450d
......@@ -17,16 +17,7 @@
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
static inline bool get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) {
if (strcmp(value, "0") == 0) {
return false;
}
return true;
}
return false;
}
static const bool use_old= get_env_("USE_VLLM_OLD_OP");
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
......@@ -418,7 +409,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(!use_old&&hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
......@@ -488,7 +479,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(!use_old&&hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment