Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
9f8f2c7f
"...models/llama/LlamaDecoderSelfAttentionLayer.h" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Unverified
Commit
9f8f2c7f
authored
Jan 22, 2025
by
Yineng Zhang
Committed by
GitHub
Jan 22, 2025
Browse files
update norm cu (#3048)
parent
6fc37bd8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1 addition
and
29 deletions
+1
-29
sgl-kernel/setup.py
sgl-kernel/setup.py
+1
-1
sgl-kernel/src/sgl-kernel/csrc/norm.cu
sgl-kernel/src/sgl-kernel/csrc/norm.cu
+0
-28
No files found.
sgl-kernel/setup.py
View file @
9f8f2c7f
...
...
@@ -91,7 +91,7 @@ ext_modules = [
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu"
,
"src/sgl-kernel/csrc/sgl_kernel_ops.cu"
,
"src/sgl-kernel/csrc/rotary_embedding.cu"
,
"
src/sgl-kernel
/csrc/norm.cu"
,
"
3rdparty/flashinfer
/csrc/norm.cu"
,
],
include_dirs
=
include_dirs
,
extra_compile_args
=
{
...
...
sgl-kernel/src/sgl-kernel/csrc/norm.cu
deleted
100644 → 0
View file @
6fc37bd8
#include <cstdint>
#include <flashinfer/norm.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;

// PyTorch extension entry point: applies RMSNorm (root-mean-square
// normalization scaled by `weight`) to `input`, writing into the
// pre-allocated `output` tensor via flashinfer's norm::RMSNorm kernel.
//
//   output      — (batch_size, hidden_size) tensor, shape-checked below
//   input       — (batch_size, hidden_size) tensor; must pass CHECK_INPUT
//   weight      — (hidden_size) tensor on the same device as input
//   eps         — numerical-stability epsilon forwarded to the kernel
//   cuda_stream — CUDA stream handle smuggled across the Python boundary
//                 as a plain integer
//
// NOTE(review): the FP16 dispatch macro suggests half-precision dtypes
// only — confirm the exact supported set in pytorch_extension_utils.h.
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) {
  // Validation runs in the same order as before so the first error
  // raised on bad input is unchanged.
  CHECK_INPUT(input);
  CHECK_INPUT(weight);
  const auto dev = input.device();
  CHECK_EQ(weight.device(), dev);
  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
  CHECK_DIM(1, weight);  // weight: (hidden_size)
  CHECK_EQ(input.size(1), weight.size(0));

  const unsigned int num_rows = input.size(0);
  const unsigned int row_width = input.size(1);
  CHECK_EQ(output.size(0), num_rows);
  CHECK_EQ(output.size(1), row_width);

  // Reconstitute the stream object from the integer handle.
  const cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    auto* in_ptr = static_cast<c_type*>(input.data_ptr());
    auto* w_ptr = static_cast<c_type*>(weight.data_ptr());
    auto* out_ptr = static_cast<c_type*>(output.data_ptr());
    const cudaError_t status = norm::RMSNorm(in_ptr, w_ptr, out_ptr, num_rows, row_width, eps, stream);
    TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment