sglang commit c245b789 (unverified)
Authored Aug 11, 2024 by Yineng Zhang; committed by GitHub on Aug 11, 2024. Parent: 9dae4078.

hotfix: add CustomOp abstraction (#1027)
Showing 2 changed files, with 7 additions and 4 deletions:

python/sglang/srt/layers/activation.py (+4, -2)
python/sglang/srt/layers/layernorm.py (+3, -2)
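Both files make the same change: SiluAndMul and RMSNorm switch their base class from nn.Module to vllm's CustomOp, and each CUDA-specific forward is renamed to forward_cuda. The point of the abstraction is that CustomOp's own forward() selects a backend-specific implementation at runtime, so callers keep invoking forward() as before. A minimal sketch of that dispatch pattern, assuming a simplified CUDA-or-native check (this is an illustration, not vllm's actual implementation):

import torch
import torch.nn as nn


class DispatchingOp(nn.Module):
    """Sketch of the CustomOp idea: forward() routes to a
    backend-specific method chosen at runtime."""

    def forward(self, *args, **kwargs):
        # On a CUDA machine, take the optimized kernel path; otherwise
        # fall back to the pure-PyTorch reference implementation.
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default to the reference path unless a subclass overrides it.
        return self.forward_native(*args, **kwargs)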
python/sglang/srt/layers/activation.py

@@ -13,15 +13,17 @@ limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from flashinfer.activation import silu_and_mul
+from vllm.model_executor.custom_op import CustomOp


-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
...
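The hunk is truncated at the torch.empty call. For orientation, here is a self-contained sketch of the resulting class; it subclasses nn.Module so the snippet stands alone, and the last two lines of forward_cuda (the fused-kernel call and the return) are an assumption about the elided tail, not taken from the diff:

import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul


class SiluAndMulSketch(nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch fallback: split the last dim in half and
        # gate the second half with SiLU of the first.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # FlashInfer's fused kernel writes into a preallocated buffer
        # whose last dim is half of the input's.
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)  # assumed call pattern; elided in the diff
        return out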
python/sglang/srt/layers/layernorm.py

@@ -18,9 +18,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm.model_executor.custom_op import CustomOp


-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
     def __init__(
         self,
         hidden_size: int,
@@ -30,7 +31,7 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
...
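The method bodies are elided in this hunk. As a rough guide to what the two paths of such an RMSNorm layer typically look like, a sketch follows; it subclasses nn.Module to stay self-contained, and the FlashInfer call pattern in forward_cuda (fused_add_rmsnorm mutating x and residual in place, rmsnorm returning a fresh tensor) is an assumption, not shown in the diff:

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm


class RMSNormSketch(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Reference path: optional residual add, then RMS normalization
        # computed in float32 for numerical stability.
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x
        hidden = x.to(torch.float32)
        variance = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(variance + self.variance_epsilon)
        out = hidden.to(orig_dtype) * self.weight
        return out if residual is None else (out, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Assumed FlashInfer usage: the fused kernel updates x and
        # residual in place; the plain kernel returns a new tensor.
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        return rmsnorm(x, self.weight.data, self.variance_epsilon)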