Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c75a3138
Unverified
Commit
c75a3138
authored
Apr 01, 2026
by
Zhanda Zhu
Committed by
GitHub
Apr 01, 2026
Browse files
[Perf] triton bilinear_pos_embed kernel for ViT (#37948)
Signed-off-by:
Zhanda Zhu
<
zhandazhu@gmail.com
>
parent
4f6eed3b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
491 additions
and
54 deletions
+491
-54
benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
+162
-0
tests/kernels/core/test_vit_bilinear_pos_embed.py
tests/kernels/core/test_vit_bilinear_pos_embed.py
+120
-0
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+209
-54
No files found.
benchmarks/kernels/benchmark_vit_bilinear_pos_embed.py
0 → 100644
View file @
c75a3138
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Benchmarks the fused Triton bilinear position-embedding kernel against
# the pure-PyTorch (native) implementation used in Qwen3-VL ViT models.
#
# == Usage Examples ==
#
# Default benchmark:
# python3 benchmark_vit_bilinear_pos_embed.py
#
# Custom parameters:
# python3 benchmark_vit_bilinear_pos_embed.py --hidden-dim 1152 \
# --num-grid-per-side 48 --save-path ./configs/vit_pos_embed/
import
itertools
import
torch
from
vllm.model_executor.models.qwen3_vl
import
(
pos_embed_interpolate_native
,
triton_pos_embed_interpolate
,
)
from
vllm.triton_utils
import
HAS_TRITON
,
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
# Spatial (h, w) grid shapes to sweep: five square grids, then two
# rectangular ones.
h_w_configs = [(side, side) for side in (16, 32, 48, 64, 128)]
h_w_configs += [(32, 48), (60, 80)]

# Temporal extents to sweep.
t_range = [1]

# Cartesian product of the temporal and spatial settings, as (t, (h, w)).
configs = [*itertools.product(t_range, h_w_configs)]
def get_benchmark(
    num_grid_per_side: int,
    spatial_merge_size: int,
    hidden_dim: int,
    dtype: torch.dtype,
    device: str,
):
    """Build a ``triton.testing`` perf report for the two interpolation paths.

    Args:
        num_grid_per_side: Side length of the square position-embedding grid;
            the embedding table has ``num_grid_per_side ** 2`` rows.
        spatial_merge_size: Merge size forwarded to both implementations.
        hidden_dim: Number of columns in the embedding table.
        dtype: Dtype of the embedding table and of the requested output.
        device: Torch device string on which the table is allocated.

    Returns:
        The ``perf_report``-decorated ``benchmark`` callable. Each sweep point
        is one ``(t, (h, w))`` entry of the module-level ``configs`` list and
        is timed once per provider (``"native"`` and ``"triton"``).
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["t", "h_w"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["native", "triton"],
            line_names=["Native (PyTorch)", "Triton"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name=(
                f"vit-bilinear-pos-embed-"
                f"grid{num_grid_per_side}-"
                f"dim{hidden_dim}-"
                f"{dtype}"
            ),
            args={},
        )
    )
    def benchmark(t, h_w, provider):
        """Time one provider on one (t, h, w) grid; returns microseconds."""
        h, w = h_w

        # Fixed seed so both providers see the same table for a given config.
        torch.manual_seed(42)
        embed_weight = (
            torch.randn(
                num_grid_per_side * num_grid_per_side,
                hidden_dim,
                device=device,
                dtype=dtype,
            )
            * 0.25
        )

        if provider == "native":
            interpolate_fn = pos_embed_interpolate_native
        else:
            assert HAS_TRITON, "Triton not available"
            interpolate_fn = triton_pos_embed_interpolate

        # do_bench returns one value per requested quantile, in order:
        # median, 20th percentile (fast end), 80th percentile (slow end).
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: interpolate_fn(
                embed_weight,
                t,
                h,
                w,
                num_grid_per_side,
                spatial_merge_size,
                dtype,
            ),
            quantiles=quantiles,
        )

        # perf_report unpacks the result as (value, min, max). The metric is
        # latency, not an inverted quantity such as throughput, so the
        # fast quantile is the minimum and the slow quantile is the maximum.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    # Script entry point: parse the sweep settings, build the report for
    # bfloat16, then print the table and write results under --save-path.
    cli = FlexibleArgumentParser(
        description="Benchmark bilinear position embedding interpolation."
    )

    # Flag name -> add_argument keyword options, registered in this order.
    option_table = (
        (
            "--num-grid-per-side",
            dict(
                type=int,
                default=48,
                help="Position embedding grid size (default: 48 for Qwen3-VL)",
            ),
        ),
        (
            "--spatial-merge-size",
            dict(
                type=int,
                default=2,
                help="Spatial merge size (default: 2)",
            ),
        ),
        (
            "--hidden-dim",
            dict(
                type=int,
                default=1152,
                help="Embedding hidden dimension (default: 1152 for Qwen3-VL)",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                choices=["cuda:0", "cuda:1"],
                default="cuda:0",
            ),
        ),
        (
            "--save-path",
            dict(
                type=str,
                default="./vit_pos_embed/",
            ),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    dtype = torch.bfloat16

    bench = get_benchmark(
        args.num_grid_per_side,
        args.spatial_merge_size,
        args.hidden_dim,
        dtype,
        args.device,
    )
    bench.run(print_data=True, save_path=args.save_path)
tests/kernels/core/test_vit_bilinear_pos_embed.py
0 → 100644
View file @
c75a3138
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Accuracy tests for the fused Triton bilinear position-embedding kernel.
Compares ``triton_pos_embed_interpolate`` against the pure-PyTorch
``pos_embed_interpolate_native`` across a variety of grid shapes and dtypes.
"""
import
pytest
import
torch
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.models.qwen3_vl
import
(
pos_embed_interpolate_native
,
triton_pos_embed_interpolate
,
)
# Dtypes every parametrized test runs under.
DTYPES = [torch.float32, torch.bfloat16]

# Qwen3-VL default
NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, HIDDEN_DIM = 48, 2, 1152

# 4 square + 4 non-square (t, h, w) grids; every h and w is divisible by
# spatial_merge_size=2.
SQUARE_GRIDS = [(1, side, side) for side in (4, 16, 32, 48)]
NON_SQUARE_GRIDS = [
    (1, rows, cols) for rows, cols in ((8, 16), (14, 20), (32, 48), (60, 80))
]
ALL_GRIDS = [*SQUARE_GRIDS, *NON_SQUARE_GRIDS]
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
@pytest.mark.parametrize(
    "grid_thw",
    ALL_GRIDS,
    ids=[f"{t}x{h}x{w}" for t, h, w in ALL_GRIDS],
)
def test_triton_matches_native(
    grid_thw: tuple[int, int, int],
    dtype: torch.dtype,
) -> None:
    """Check the Triton path against the eager PyTorch path on one grid."""
    t, h, w = grid_thw

    # Random table scaled by 0.25 to resemble the real Qwen3-VL pos_embed
    # weight distribution (std~0.23); seeded for reproducibility.
    torch.manual_seed(42)
    table = torch.randn(
        NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
        HIDDEN_DIM,
        device="cuda",
        dtype=dtype,
    )
    embed_weight = table * 0.25

    # Both implementations take the same positional arguments.
    call_args = (embed_weight, t, h, w, NUM_GRID_PER_SIDE, SPATIAL_MERGE_SIZE, dtype)
    native_out = pos_embed_interpolate_native(*call_args)
    triton_out = triton_pos_embed_interpolate(*call_args)

    assert native_out.shape == triton_out.shape, (
        f"Shape mismatch: native {native_out.shape} vs triton {triton_out.shape}"
    )

    # The Triton path uses precomputed h/w scales where the native path uses
    # torch.linspace, which can leave single-ULP differences in a handful of
    # elements; the per-dtype (atol, rtol) pairs below absorb that.
    tolerances = {
        torch.float32: (5e-5, 1e-5),
        torch.bfloat16: (1e-2, 1e-2),
    }
    atol, rtol = tolerances[dtype]
    torch.testing.assert_close(triton_out, native_out, atol=atol, rtol=rtol)
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES, ids=lambda d: str(d).split(".")[-1])
def test_temporal_repeat(dtype: torch.dtype) -> None:
    """A t > 1 result must be the t == 1 result tiled t times, bit for bit."""
    h, w = 16, 16
    t_single, t_multi = 1, 3

    # Random table scaled by 0.25 to resemble the real Qwen3-VL pos_embed
    # weight distribution (std~0.23); seeded for reproducibility.
    torch.manual_seed(42)
    embed_weight = 0.25 * torch.randn(
        NUM_GRID_PER_SIDE * NUM_GRID_PER_SIDE,
        HIDDEN_DIM,
        device="cuda",
        dtype=dtype,
    )

    def interpolate(num_frames: int) -> torch.Tensor:
        # Same table and spatial grid; only the temporal extent varies.
        return triton_pos_embed_interpolate(
            embed_weight,
            num_frames,
            h,
            w,
            NUM_GRID_PER_SIDE,
            SPATIAL_MERGE_SIZE,
            dtype,
        )

    out_single = interpolate(t_single)
    out_multi = interpolate(t_multi)

    # Stack the single-frame rows t_multi times along dim 0.
    expected = out_single.repeat(t_multi, 1)

    # Zero tolerances: the repeated frames must match exactly.
    torch.testing.assert_close(out_multi, expected, atol=0, rtol=0)
vllm/model_executor/models/qwen3_vl.py
View file @
c75a3138
...
...
@@ -96,6 +96,7 @@ from vllm.multimodal.processing import (
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers.protocol
import
TokenizerLike
from
vllm.tokenizers.registry
import
cached_tokenizer_from_config
from
vllm.triton_utils
import
HAS_TRITON
,
tl
,
triton
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.math_utils
import
round_up
...
...
@@ -145,6 +146,201 @@ logger = init_logger(__name__)
# of the maximum size.
DUMMY_VIDEO_NUM_FRAMES
=
2048
# ---------------------------------------------------------------------------
# Triton kernel: fused bilinear position-embedding interpolation
# ---------------------------------------------------------------------------
# Replaces many small eager-mode CUDA kernels with a single launch.
# The spatial-merge reorder is baked into the index math so the output
# is ready to be added to the patch embeddings directly.
# ---------------------------------------------------------------------------
if HAS_TRITON:

    @triton.jit
    def _bilinear_pos_embed_kernel(
        embed_ptr,
        output_ptr,
        H,
        W,
        h_scale,
        w_scale,
        NUM_GRID: tl.constexpr,
        M_SIZE: tl.constexpr,
        HIDDEN_DIM: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Fused bilinear pos-embed interpolation with spatial-merge reorder.

        Each program writes one output row of ``HIDDEN_DIM`` elements at row
        index ``program_id(0)``.

        Args:
            embed_ptr: Embedding table, addressed as
                ``(grid_row * NUM_GRID + grid_col) * HIDDEN_DIM + column``.
            output_ptr: Output buffer, addressed as ``pid * HIDDEN_DIM + column``.
            H, W: Target grid height and width.
            h_scale, w_scale: Factors mapping a target row / column index to a
                fractional source-grid coordinate.
            NUM_GRID: Side length of the source grid.
            M_SIZE: Spatial merge size; output rows are ordered block by block
                over ``M_SIZE x M_SIZE`` tiles.
            HIDDEN_DIM: Number of elements per embedding row.
            BLOCK_D: Number of columns processed per loop iteration.
        """
        pid = tl.program_id(0)

        # Position within one H*W frame; frames beyond the first repeat it.
        total_spatial = H * W
        spatial_idx = pid % total_spatial

        # Decode the merge-ordered index into (tile, offset-in-tile), then
        # into the plain (row, col) of the target grid.
        num_blocks_w = W // M_SIZE
        block_idx = spatial_idx // (M_SIZE * M_SIZE)
        local_idx = spatial_idx % (M_SIZE * M_SIZE)
        br = block_idx // num_blocks_w
        bc = block_idx % num_blocks_w
        lr = local_idx // M_SIZE
        lc = local_idx % M_SIZE
        row = br * M_SIZE + lr
        col = bc * M_SIZE + lc

        # Fractional source coordinates and their floor / clamped-ceil cells.
        h_frac = row.to(tl.float32) * h_scale
        w_frac = col.to(tl.float32) * w_scale
        hf = tl.math.floor(h_frac).to(tl.int32)
        wf = tl.math.floor(w_frac).to(tl.int32)
        hc = tl.minimum(hf + 1, NUM_GRID - 1)
        wc = tl.minimum(wf + 1, NUM_GRID - 1)
        dh = h_frac - hf.to(tl.float32)
        dw = w_frac - wf.to(tl.float32)

        # Bilinear weights, reusing dh * dw:
        #   w00 = (1 - dh) * (1 - dw), w01 = (1 - dh) * dw,
        #   w10 = dh * (1 - dw),       w11 = dh * dw
        w11 = dh * dw
        w10 = dh - w11
        w01 = dw - w11
        w00 = 1.0 - dh - w01

        off00 = (hf * NUM_GRID + wf) * HIDDEN_DIM
        off01 = (hf * NUM_GRID + wc) * HIDDEN_DIM
        off10 = (hc * NUM_GRID + wf) * HIDDEN_DIM
        off11 = (hc * NUM_GRID + wc) * HIDDEN_DIM

        # program_id is 32-bit; widen before scaling so the output offset
        # does not wrap once t * h * w * HIDDEN_DIM reaches 2**31.
        out_off = pid.to(tl.int64) * HIDDEN_DIM

        # Cast weights to output dtype so the multiply-accumulate stays
        # in the same precision as the native PyTorch implementation.
        out_dtype = output_ptr.dtype.element_ty
        w00_c = w00.to(out_dtype)
        w01_c = w01.to(out_dtype)
        w10_c = w10.to(out_dtype)
        w11_c = w11.to(out_dtype)

        for d in tl.range(0, HIDDEN_DIM, BLOCK_D):
            cols = d + tl.arange(0, BLOCK_D)
            # Mask off the padding columns when BLOCK_D overshoots HIDDEN_DIM.
            mask = cols < HIDDEN_DIM

            e00 = tl.load(embed_ptr + off00 + cols, mask=mask)
            e01 = tl.load(embed_ptr + off01 + cols, mask=mask)
            e10 = tl.load(embed_ptr + off10 + cols, mask=mask)
            e11 = tl.load(embed_ptr + off11 + cols, mask=mask)

            val = w00_c * e00 + w01_c * e01 + w10_c * e10 + w11_c * e11
            tl.store(output_ptr + out_off + cols, val, mask=mask)
def triton_pos_embed_interpolate(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Launch the fused Triton kernel for one (t,h,w) grid.

    Args:
        embed_weight: 2-D embedding table; its second dimension is the
            hidden size. Row ``r * num_grid_per_side + c`` is read as source
            grid cell ``(r, c)``.
        t: Number of temporal repeats of the spatial pattern.
        h: Target grid height; must be divisible by ``m_size``.
        w: Target grid width; must be divisible by ``m_size``.
        num_grid_per_side: Side length of the source grid.
        m_size: Spatial merge size used for the output row ordering.
        dtype: Dtype of the returned tensor.

    Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the
    bilinearly-interpolated position embeddings in spatial-merge order.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )

    hidden_dim = embed_weight.shape[1]
    total_out = t * h * w
    output = torch.empty(
        total_out,
        hidden_dim,
        device=embed_weight.device,
        dtype=dtype,
    )
    if total_out == 0:
        # Nothing to compute; skip the launch entirely.
        return output

    # The kernel addresses the table as ``row * hidden_dim + col``, i.e. it
    # assumes a dense row-major layout. This is a no-op for tensors that are
    # already contiguous and a copy for strided views.
    embed_weight = embed_weight.contiguous()

    # Step between consecutive target rows/cols in source-grid coordinates,
    # so index 0 maps to 0 and the last index maps to num_grid_per_side - 1.
    # A single-row/col target maps to coordinate 0.
    h_scale = float(num_grid_per_side - 1) / float(h - 1) if h > 1 else 0.0
    w_scale = float(num_grid_per_side - 1) / float(w - 1) if w > 1 else 0.0

    # One block wide enough to cover a full embedding row.
    BLOCK_D = triton.next_power_of_2(hidden_dim)

    # One program per output row.
    _bilinear_pos_embed_kernel[(total_out,)](
        embed_weight,
        output,
        h,
        w,
        h_scale,
        w_scale,
        num_grid_per_side,
        m_size,
        hidden_dim,
        BLOCK_D,
    )
    return output
def pos_embed_interpolate_native(
    embed_weight: torch.Tensor,
    t: int,
    h: int,
    w: int,
    num_grid_per_side: int,
    m_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Eager PyTorch bilinear position-embedding interpolation.

    Samples the ``num_grid_per_side x num_grid_per_side`` table stored in
    ``embed_weight`` on an ``h x w`` grid of evenly spaced coordinates,
    reorders the result into ``m_size x m_size`` tiles and repeats it ``t``
    times.

    Returns a tensor of shape ``(t * h * w, hidden_dim)`` with the
    bilinearly-interpolated position embeddings in spatial-merge order.
    """
    assert h % m_size == 0 and w % m_size == 0, (
        f"h={h} and w={w} must be divisible by m_size={m_size}"
    )

    hidden_dim = embed_weight.shape[1]
    device = embed_weight.device
    last_cell = num_grid_per_side - 1

    def axis_terms(count):
        # Evenly spaced source coordinates along one axis, with their floor
        # cell, clamped next cell and fractional offset.
        coords = torch.linspace(
            0, last_cell, count, dtype=torch.float32, device=device
        )
        lower = coords.to(torch.long)
        upper = torch.clamp(lower + 1, max=last_cell)
        return lower, upper, coords - lower

    h_floor, h_ceil, dh = axis_terms(h)
    w_floor, w_ceil, dw = axis_terms(w)

    # Broadcast the per-axis terms to full (h, w) grids.
    dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
    h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
    h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")

    # Bilinear weights; w11 is reused so dh * dw is computed only once:
    #   w00 = (1 - dh) * (1 - dw), w01 = (1 - dh) * dw,
    #   w10 = dh * (1 - dw),       w11 = dh * dw
    w11 = dh_grid * dw_grid
    w10 = dh_grid - w11
    w01 = dw_grid - w11
    w00 = 1 - dh_grid - w01

    # Flat table indices of the four corners, ordered 00, 01, 10, 11.
    corner_rows = torch.stack(
        [h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]
    )
    corner_cols = torch.stack(
        [w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]
    )
    indices = (corner_rows * num_grid_per_side + corner_cols).reshape(4, -1)

    weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
    weights = weights.to(dtype=dtype)

    # Advanced indexing yields a fresh tensor, so the in-place scaling below
    # does not touch embed_weight.
    embeds = embed_weight[indices]
    embeds.mul_(weights)
    combined = embeds.sum(dim=0)

    # (h, w, D) -> (h/m, m, w/m, m, D) -> (h/m, w/m, m, m, D): rows of each
    # m_size x m_size tile become adjacent in the flattened output.
    tiled = combined.reshape(
        h // m_size, m_size, w // m_size, m_size, hidden_dim
    )
    merged = tiled.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)

    # Repeat the spatial pattern for every temporal step.
    repeated = merged.expand(t, -1, -1).reshape(-1, hidden_dim)
    return repeated.to(dtype=dtype)
class
Qwen3_VisionPatchEmbed
(
nn
.
Module
):
def
__init__
(
...
...
@@ -470,63 +666,22 @@ class Qwen3_VisionTransformer(nn.Module):
return
cos_combined
,
sin_combined
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
num_grid_per_side
=
self
.
num_grid_per_side
m_size
=
self
.
spatial_merge_size
hidden_dim
=
self
.
pos_embed
.
embedding_dim
interpolate_fn
=
(
triton_pos_embed_interpolate
if
HAS_TRITON
else
pos_embed_interpolate_native
)
outputs
=
[]
for
t
,
h
,
w
in
grid_thw
:
h_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
h
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
w_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
w
,
dtype
=
torch
.
float32
,
device
=
self
.
device
outputs
.
append
(
interpolate_fn
(
self
.
pos_embed
.
weight
,
t
,
h
,
w
,
self
.
num_grid_per_side
,
self
.
spatial_merge_size
,
self
.
dtype
,
)
h_floor
=
h_idxs
.
to
(
torch
.
long
)
w_floor
=
w_idxs
.
to
(
torch
.
long
)
h_ceil
=
torch
.
clamp
(
h_floor
+
1
,
max
=
num_grid_per_side
-
1
)
w_ceil
=
torch
.
clamp
(
w_floor
+
1
,
max
=
num_grid_per_side
-
1
)
dh
=
h_idxs
-
h_floor
dw
=
w_idxs
-
w_floor
# Create meshgrid view for all h, w vars
dh_grid
,
dw_grid
=
torch
.
meshgrid
(
dh
,
dw
,
indexing
=
"ij"
)
h_floor_grid
,
w_floor_grid
=
torch
.
meshgrid
(
h_floor
,
w_floor
,
indexing
=
"ij"
)
h_ceil_grid
,
w_ceil_grid
=
torch
.
meshgrid
(
h_ceil
,
w_ceil
,
indexing
=
"ij"
)
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11
=
dh_grid
*
dw_grid
w10
=
dh_grid
-
w11
w01
=
dw_grid
-
w11
w00
=
1
-
dh_grid
-
w01
h_grid
=
torch
.
stack
([
h_floor_grid
,
h_floor_grid
,
h_ceil_grid
,
h_ceil_grid
])
w_grid
=
torch
.
stack
([
w_floor_grid
,
w_ceil_grid
,
w_floor_grid
,
w_ceil_grid
])
h_grid_idx
=
h_grid
*
num_grid_per_side
indices
=
(
h_grid_idx
+
w_grid
).
reshape
(
4
,
-
1
)
weights
=
torch
.
stack
([
w00
,
w01
,
w10
,
w11
],
dim
=
0
).
reshape
(
4
,
-
1
,
1
)
weights
=
weights
.
to
(
dtype
=
self
.
dtype
)
embeds
=
self
.
pos_embed
(
indices
)
embeds
*=
weights
combined
=
embeds
.
sum
(
dim
=
0
)
combined
=
combined
.
reshape
(
h
//
m_size
,
m_size
,
w
//
m_size
,
m_size
,
hidden_dim
)
combined
=
combined
.
permute
(
0
,
2
,
1
,
3
,
4
).
reshape
(
1
,
-
1
,
hidden_dim
)
repeated
=
combined
.
expand
(
t
,
-
1
,
-
1
).
reshape
(
-
1
,
hidden_dim
)
outputs
.
append
(
repeated
)
return
torch
.
cat
(
outputs
,
dim
=
0
)
def
prepare_encoder_metadata
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment