norm / vllm · Commits · Commit bd29cf3d (Unverified)
Authored Dec 20, 2023 by Antoni Baum; committed by GitHub, Dec 20, 2023
Remove Sampler copy stream (#2209)
Parent: 31bff691
Showing 1 changed file with 4 additions and 9 deletions (+4, -9):
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py

@@ -30,7 +30,6 @@ class Sampler(nn.Module):
     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
-        self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -51,14 +50,10 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, sampling_metadata)
-        # Prepare sampling tensors in another stream to overlap
-        # CPU<->GPU data transfer with GPU computation in forward pass.
-        with torch.cuda.stream(self._copy_stream):
-            (sampling_tensors, do_penalties, do_top_p_top_k,
-             do_min_p) = SamplingTensors.from_sampling_metadata(
-                 sampling_metadata, vocab_size, logits.device, logits.dtype)
-
-        torch.cuda.current_stream().wait_stream(self._copy_stream)
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        (sampling_tensors, do_penalties, do_top_p_top_k,
+         do_min_p) = SamplingTensors.from_sampling_metadata(
+             sampling_metadata, vocab_size, logits.device, logits.dtype)
 
         # Apply presence and frequency penalties.
         if do_penalties:
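For context, here is a minimal standalone sketch of the two host-to-device transfer patterns this commit trades between: the removed approach (launch the copy on a dedicated side stream, then synchronize with `wait_stream`) and the retained approach (pinned host memory plus a `non_blocking` copy on the current stream). The helper names are hypothetical and not part of vLLM; the sketch assumes PyTorch and falls back to CPU when no GPU is available.

```python
import torch


def to_gpu_via_side_stream(values, stream):
    # Pattern removed by this commit: build the tensor and launch the
    # H2D copy on a dedicated stream, then force the current stream to
    # wait, adding cross-stream synchronization overhead.
    with torch.cuda.stream(stream):
        t = torch.tensor(values, dtype=torch.float32, pin_memory=True)
        gpu_t = t.to("cuda", non_blocking=True)
    torch.cuda.current_stream().wait_stream(stream)
    return gpu_t


def to_gpu_pinned(values):
    # Pattern kept by this commit: pinned (page-locked) host memory lets
    # .to(..., non_blocking=True) run as an asynchronous DMA copy on the
    # current stream, so no extra stream or wait_stream() call is needed.
    t = torch.tensor(values, dtype=torch.float32)
    if not torch.cuda.is_available():
        return t  # CPU fallback for machines without a GPU
    return t.pin_memory().to("cuda", non_blocking=True)
```

Because ordering on a single stream is already guaranteed, dropping the side stream simplifies `Sampler.__init__` while the pinned-memory copy keeps the transfer from blocking the host.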