Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbefd03
Commit
2dbefd03
authored
Aug 20, 2024
by
zhangshao
Browse files
Update layernorm_kernels.cu
parent
785f450d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
12 deletions
+3
-12
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+3
-12
No files found.
csrc/layernorm_kernels.cu
View file @
2dbefd03
...
@@ -17,16 +17,7 @@
...
@@ -17,16 +17,7 @@
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
#endif
static
inline
bool
get_env_
(
const
char
*
env_var
)
{
if
(
char
*
value
=
std
::
getenv
(
env_var
))
{
if
(
strcmp
(
value
,
"0"
)
==
0
)
{
return
false
;
}
return
true
;
}
return
false
;
}
static
const
bool
use_old
=
get_env_
(
"USE_VLLM_OLD_OP"
);
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
...
@@ -418,7 +409,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
...
@@ -418,7 +409,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
!
use_old
&&
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
at
::
ScalarType
::
BFloat16
,
...
@@ -488,7 +479,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
...
@@ -488,7 +479,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
!
use_old
&&
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
&&
ptrs_are_aligned
){
if
(
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
at
::
ScalarType
::
BFloat16
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment