xdb4_94051 / vllm / Commits

Commit cfc15a10 (unverified)
Optimize Triton MoE Kernel (#2979)
Authored Feb 26, 2024 by Philipp Moritz; committed by GitHub on Feb 26, 2024
Co-authored-by: Cade Daniel <edacih@gmail.com>
Parent: 70f3e8e3
Showing 7 changed files with 297 additions and 15 deletions (+297, -15)
Changed files:
- benchmarks/kernels/benchmark_mixtral_moe.py (+172, -0)
- setup.py (+3, -1)
- vllm/model_executor/layers/fused_moe/__init__.py (+5, -0)
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (+20, -0)
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (+24, -0)
- vllm/model_executor/layers/fused_moe/configs/README (+10, -0)
- vllm/model_executor/layers/fused_moe/fused_moe.py (+63, -14)
benchmarks/kernels/benchmark_mixtral_moe.py (new file, mode 100644)
import json
import os
import sys

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20

    for config in configs:
        print(f'{tp_size=} {bs=}')
        print(f'{config}')

        # warmup
        print('warming up')
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        print('benchmarking')
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                  f' {bs=} {tp_size=} {top_k=} {num_total_experts=}'
                  f' {d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    filename = "/tmp/config.jsonl"
    print(f"writing config to file {filename}")
    with open(filename, "a") as f:
        f.write(json.dumps({str(bs): best_config}) + "\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gating_output = F.softmax(
        torch.rand(
            (num_calls, bs, num_total_experts),
            device=hidden_states.device,
            dtype=torch.float32,
        ),
        dim=-1,
    )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
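The grid search above appends one JSON line per batch size to /tmp/config.jsonl. Below is a minimal post-processing sketch (not part of this commit; the script name and output location are illustrative) that merges those lines into a single file named the way the new get_moe_configs() lookup in fused_moe.py expects, using the benchmark defaults of E=8 experts and N = 14336 / tp_size = 7168:

# merge_moe_configs.py -- illustrative helper, not part of the commit.
# Merges the per-batch-size lines written by run_grid() above into one
# JSON mapping, named "E=<E>,N=<N>,device_name=<device>.json" so that
# get_moe_configs() can pick it up when placed in the configs/ directory.
import json

import torch

E = 8                 # num_total_experts in the benchmark above
N = 14336 // 2        # model_intermediate_size / tp_size
device_name = torch.cuda.get_device_name().replace(" ", "_")

merged = {}
with open("/tmp/config.jsonl") as f:
    for line in f:
        merged.update(json.loads(line))  # each line is {"<batch size>": best_config}

out_path = f"E={E},N={N},device_name={device_name}.json"
with open(out_path, "w") as f:
    json.dump(merged, f, indent=4)
print("wrote", out_path)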
setup.py

@@ -432,7 +432,9 @@ def get_requirements() -> List[str]:
     return requirements
 
-package_data = {"vllm": ["py.typed"]}
+package_data = {
+    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
+}
 
 if os.environ.get("VLLM_USE_PRECOMPILED"):
     ext_modules = []
     package_data["vllm"].append("*.so")
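The package_data addition is what ships the tuned JSON files inside the wheel, so the runtime lookup can find them next to fused_moe.py. A quick sanity check (not part of the commit), assuming a vllm build that includes this change is installed:

# Sanity check: list the tuned configs packaged alongside the fused_moe
# module in the installed vllm wheel.
import os

import vllm.model_executor.layers.fused_moe as fused_moe_pkg

configs_dir = os.path.join(os.path.dirname(fused_moe_pkg.__file__), "configs")
print(sorted(os.listdir(configs_dir)))
# expected to include the two example files added by this commit, e.g.
# 'E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json'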
vllm/model_executor/layers/fused_moe/__init__.py (new file, mode 100644)

from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (new file, mode 100644)

{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (new file, mode 100644)

{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/README (new file, mode 100644)
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
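To make the mapping concrete, here is a small sketch (not part of the commit) that loads the A100 example file above and picks the entry for a given batch size M using the same closest-key rule that fused_moe applies further down in this diff:

# Illustration of the batch-size -> config mapping described in the README.
# The path assumes the current working directory is this configs/ directory.
import json

with open("E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json") as f:
    configs = {int(k): v for k, v in json.load(f).items()}

M = 21  # actual batch size; 21 is not a key, so pick the closest one
best_key = min(configs.keys(), key=lambda x: abs(x - M))
print(best_key, configs[best_key])  # -> 24 and its tuned kernel config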
vllm/model_executor/layers/fused_moe.py → vllm/model_executor/layers/fused_moe/fused_moe.py
"""Fused MoE kernel."""
import
functools
import
json
import
os
from
typing
import
Any
,
Dict
,
Optional
import
torch
import
triton
import
triton.language
as
tl
from
vllm._C
import
ops
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_hip
logger
=
init_logger
(
__name__
)
@
triton
.
jit
def
fused_moe_kernel
(
...
...
@@ -210,6 +218,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
+@functools.lru_cache
+def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of batch sizes
+    to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
+    size bs, the closest batch size in the grid should be picked and the associated
+    configuration chosen to invoke the kernel.
+    """
+    # First look up if an optimized configuration is available in the configs directory
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs",
+        f"E={E},N={N},device_name={device_name}.json")
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                f"Using configuration from {config_file_path} for MoE layer.")
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default configuration
+    return None
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
...
@@ -218,6 +254,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
...
@@ -230,6 +267,7 @@ def fused_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
...
@@ -279,20 +317,31 @@ def fused_moe(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    config = {
-        'BLOCK_SIZE_M': 64,
-        'BLOCK_SIZE_N': 64,
-        'BLOCK_SIZE_K': 32,
-        'GROUP_SIZE_M': 8
-    }
-
-    if topk_ids.numel() <= w1.shape[0]:
-        config = {
-            'BLOCK_SIZE_M': 16,
-            'BLOCK_SIZE_N': 32,
-            'BLOCK_SIZE_K': 64,
-            'GROUP_SIZE_M': 1
-        }
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2])
+
+        if configs:
+            # If an optimal configuration map has been found, look up the optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32,
+                'GROUP_SIZE_M': 8
+            }
+
+            if M <= E:
+                config = {
+                    'BLOCK_SIZE_M': 16,
+                    'BLOCK_SIZE_N': 32,
+                    'BLOCK_SIZE_K': 64,
+                    'GROUP_SIZE_M': 1
+                }
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
...
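With this change, a plain call to fused_moe (no override_config) on a GPU that has a matching file in configs/ automatically uses the tuned kernel parameters; otherwise it falls back to the defaults shown above. An end-to-end usage sketch (not part of the commit), with shapes taken from the Mixtral TP2 setting in the benchmark:

# Usage sketch; tensor shapes follow the Mixtral TP2 setting used in
# benchmarks/kernels/benchmark_mixtral_moe.py above.
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fused_moe import fused_moe

E, d_model, shard_N, bs = 8, 4096, 7168, 16
x = torch.rand((bs, d_model), device="cuda:0", dtype=torch.bfloat16)
w1 = torch.rand((E, 2 * shard_N, d_model), device=x.device, dtype=x.dtype)
w2 = torch.rand((E, d_model, shard_N), device=x.device, dtype=x.dtype)
gating = F.softmax(
    torch.rand((bs, E), device=x.device, dtype=torch.float32), dim=-1)

# No override_config: on an H100 this picks up
# configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json via get_moe_configs().
out = fused_moe(hidden_states=x, w1=w1, w2=w2, gating_output=gating,
                topk=2, renormalize=True, inplace=False)
print(out.shape)  # (16, 4096)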