Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
785f450d
Commit
785f450d
authored
Aug 20, 2024
by
zhangshao
Browse files
修复rmsnorm bug,增加USE_VLLM_OLD_OP标志使用原版rmsnorm
parent
1c5e7720
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
23 deletions
+38
-23
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+38
-23
No files found.
csrc/layernorm_kernels.cu
View file @
785f450d
...
@@ -17,7 +17,16 @@
...
@@ -17,7 +17,16 @@
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
#endif
// Reports whether the environment variable named `env_var` is switched on.
//
// Returns false when the variable is not set, or when its value is exactly
// the string "0". Any other value (including an empty string) yields true.
static inline bool get_env_(const char* env_var) {
  char* setting = std::getenv(env_var);
  if (setting == nullptr) {
    // Variable is absent from the environment: treat the flag as off.
    return false;
  }
  // Only the literal "0" disables the flag; everything else enables it.
  return strcmp(setting, "0") != 0;
}
// File-scope flag initialized from the USE_VLLM_OLD_OP environment variable
// via get_env_(): true when the variable is set to anything other than "0",
// false when it is unset or "0". The value is computed once, when this
// static is initialized, so changing the environment afterwards has no
// effect on it. The host-side dispatch conditions in rms_norm and
// fused_add_rms_norm test `!use_old` before taking their
// hidden-size/alignment-gated branch.
static
const
bool
use_old
=
get_env_
(
"USE_VLLM_OLD_OP"
);
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
...
@@ -332,7 +341,6 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
...
@@ -332,7 +341,6 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
int
i
=
blockIdx
.
x
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
int
tcol
=
cols
/
Vec
;
if
(
j
>=
tcol
)
return
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
scalar_t
intput_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
...
@@ -341,22 +349,26 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
...
@@ -341,22 +349,26 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
idx
*=
Vec
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
#pragma unroll
if
(
j
<
tcol
)
{
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
#pragma unroll
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
}
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
__syncthreads
();
trstd
=
s_rstd
;
trstd
=
s_rstd
;
#pragma unroll
if
(
j
<
tcol
)
{
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
#pragma unroll
int
jj
=
j
*
Vec
+
ii
;
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
...
@@ -369,27 +381,30 @@ __global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t*
...
@@ -369,27 +381,30 @@ __global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t*
int
i
=
blockIdx
.
x
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
int
tcol
=
cols
/
Vec
;
if
(
j
>=
tcol
)
return
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
scalar_t
intput_vec
[
Vec
];
T_ACC
trstd
;
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
#pragma unroll
if
(
j
<
tcol
)
{
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
#pragma unroll
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
}
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
__syncthreads
();
trstd
=
s_rstd
;
trstd
=
s_rstd
;
#pragma unroll
if
(
j
<
tcol
)
{
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
#pragma unroll
int
jj
=
j
*
Vec
+
ii
;
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
output
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
*
(
LoadT
*
)(
output
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
...
@@ -403,7 +418,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
...
@@ -403,7 +418,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
if
(
!
use_old
&&
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
at
::
ScalarType
::
BFloat16
,
...
@@ -473,7 +488,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
...
@@ -473,7 +488,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
&&
ptrs_are_aligned
){
if
(
!
use_old
&&
hidden_size
%
16
==
0
&&
hidden_size
>=
2048
&&
hidden_size
<=
8192
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
at
::
ScalarType
::
BFloat16
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment