Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
f2938031
Unverified
Commit
f2938031
authored
Aug 09, 2024
by
UnicornChan
Committed by
GitHub
Aug 09, 2024
Browse files
Merge pull request #27 from chenht2022/develop-0.1.2
[Feature] towards 0.1.2
parents
442e13bc
c1cc7d2c
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
ktransformers/operators/experts.py
ktransformers/operators/experts.py
+4
-4
No files found.
ktransformers/operators/experts.py
View file @
f2938031
...
...
@@ -155,7 +155,7 @@ class MLPCPUExperts(MLPExpertsBase):
self
.
moe
=
MOE
(
moe_config
)
self
.
cpu_infer
=
MLPCPUExperts
.
CPU_INFER
if
warmup
:
self
.
cpu_infer
.
submit
(
self
.
moe
.
warm_up
)
self
.
cpu_infer
.
submit
(
self
.
moe
.
warm_up
()
)
self
.
cpu_infer
.
sync
()
if
MLPCPUExperts
.
output_gpu
==
None
:
MLPCPUExperts
.
input_tensor_cpu
=
torch
.
empty
((
self
.
config
.
hidden_size
),
device
=
"cpu"
,
pin_memory
=
True
)
...
...
@@ -168,7 +168,7 @@ class MLPCPUExperts(MLPExpertsBase):
MLPCPUExperts
.
input_tensor_cpu
.
copy_
(
input_tensor
,
non_blocking
=
True
)
MLPCPUExperts
.
expert_ids_cpu
.
copy_
(
expert_ids
,
non_blocking
=
True
)
MLPCPUExperts
.
weights_cpu
.
copy_
(
weights
,
non_blocking
=
True
)
self
.
cpu_infer
.
submit_with_cuda_stream
(
torch
.
cuda
.
current_stream
().
cuda_stream
,
self
.
moe
.
forward
,
1
,
expert_ids
.
size
(
0
),
MLPCPUExperts
.
expert_ids_cpu
.
data_ptr
(),
MLPCPUExperts
.
weights_cpu
.
data_ptr
(),
MLPCPUExperts
.
input_tensor_cpu
.
data_ptr
(),
MLPCPUExperts
.
output_cpu
.
data_ptr
())
self
.
cpu_infer
.
submit_with_cuda_stream
(
torch
.
cuda
.
current_stream
().
cuda_stream
,
self
.
moe
.
forward
(
1
,
expert_ids
.
size
(
0
),
MLPCPUExperts
.
expert_ids_cpu
.
data_ptr
(),
MLPCPUExperts
.
weights_cpu
.
data_ptr
(),
MLPCPUExperts
.
input_tensor_cpu
.
data_ptr
(),
MLPCPUExperts
.
output_cpu
.
data_ptr
())
)
def
sync_for_one_decode
(
self
):
self
.
cpu_infer
.
sync_with_cuda_stream
(
torch
.
cuda
.
current_stream
().
cuda_stream
)
...
...
@@ -183,7 +183,7 @@ class MLPCPUExperts(MLPExpertsBase):
MLPCPUExperts
.
input_tensor_cpu
.
copy_
(
input_tensor
,
non_blocking
=
True
)
MLPCPUExperts
.
expert_ids_cpu
.
copy_
(
expert_ids
,
non_blocking
=
True
)
MLPCPUExperts
.
weights_cpu
.
copy_
(
weights
,
non_blocking
=
True
)
self
.
cpu_infer
.
submit_with_cuda_stream
(
torch
.
cuda
.
current_stream
().
cuda_stream
,
self
.
moe
.
forward
,
1
,
expert_ids
.
size
(
1
),
MLPCPUExperts
.
expert_ids_cpu
.
data_ptr
(),
MLPCPUExperts
.
weights_cpu
.
data_ptr
(),
MLPCPUExperts
.
input_tensor_cpu
.
data_ptr
(),
MLPCPUExperts
.
output_cpu
.
data_ptr
())
self
.
cpu_infer
.
submit_with_cuda_stream
(
torch
.
cuda
.
current_stream
().
cuda_stream
,
self
.
moe
.
forward
(
1
,
expert_ids
.
size
(
1
),
MLPCPUExperts
.
expert_ids_cpu
.
data_ptr
(),
MLPCPUExperts
.
weights_cpu
.
data_ptr
(),
MLPCPUExperts
.
input_tensor_cpu
.
data_ptr
(),
MLPCPUExperts
.
output_cpu
.
data_ptr
())
)
self
.
cpu_infer
.
sync_with_cuda_stream
(
torch
.
cuda
.
current_stream
().
cuda_stream
)
MLPCPUExperts
.
output_gpu
.
copy_
(
MLPCPUExperts
.
output_cpu
,
non_blocking
=
True
)
#print("capturing experts finish")
...
...
@@ -193,7 +193,7 @@ class MLPCPUExperts(MLPExpertsBase):
expert_ids
=
expert_ids
.
contiguous
().
cpu
()
weights
=
weights
.
contiguous
().
to
(
torch
.
float32
).
cpu
()
output
=
torch
.
empty_like
(
input_tensor
).
contiguous
()
self
.
cpu_infer
.
submit
(
self
.
moe
.
forward
,
expert_ids
.
size
(
0
),
expert_ids
.
size
(
1
),
expert_ids
.
data_ptr
(),
weights
.
data_ptr
(),
input_tensor
.
data_ptr
(),
output
.
data_ptr
())
self
.
cpu_infer
.
submit
(
self
.
moe
.
forward
(
expert_ids
.
size
(
0
),
expert_ids
.
size
(
1
),
expert_ids
.
data_ptr
(),
weights
.
data_ptr
(),
input_tensor
.
data_ptr
(),
output
.
data_ptr
())
)
self
.
cpu_infer
.
sync
()
return
output
.
to
(
device
=
object
.
__getattribute__
(
self
,
"device"
))
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment