Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d8ae62c7
Commit
d8ae62c7
authored
Aug 22, 2024
by
zhangshao
Browse files
Update layernorm_kernels_opt.cu
parent
bf278a88
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+3
-3
No files found.
csrc/opt/layernorm_kernels_opt.cu
View file @
d8ae62c7
...
@@ -338,9 +338,9 @@ __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scal
...
@@ -338,9 +338,9 @@ __global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scal
T_ACC
trstd
;
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
if
(
j
<
tcol
)
{
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
...
@@ -377,8 +377,8 @@ __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t*
...
@@ -377,8 +377,8 @@ __global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t*
T_ACC
trstd
;
T_ACC
trstd
;
int
idx
=
i
*
tcol
+
j
;
int
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
idx
*=
Vec
;
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
if
(
j
<
tcol
)
{
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment