Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c676e3d
Commit
4c676e3d
authored
Jun 20, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.1' into v0.9.1-dev
parents
b4c4464d
b6553be1
Changes
418
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1472 additions
and
1221 deletions
+1472
-1221
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
+0
-31
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
+0
-20
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
+0
-31
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
+0
-18
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+0
-588
csrc/moe/marlin_moe_wna16/.gitignore
csrc/moe/marlin_moe_wna16/.gitignore
+1
-0
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+23
-9
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+16
-17
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+291
-281
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+205
-192
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+4
-4
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+7
-1
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+230
-0
csrc/moe/moe_wna16_utils.h
csrc/moe/moe_wna16_utils.h
+8
-8
csrc/moe/permute_unpermute_kernels/dispatch.h
csrc/moe/permute_unpermute_kernels/dispatch.h
+59
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+231
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+95
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+211
-0
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+57
-18
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+34
-3
No files found.
Too many changes to show.
To preserve performance only
418 of 418+
files are displayed.
Plain diff
Email patch
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
deleted
100644 → 0
View file @
b4c4464d
#include "marlin_moe_kernel_ku4b8.h"
namespace
marlin_moe
{
// Tries to launch a Marlin MoE kernel instantiation for the vllm::kU4B8
// quantized type.
//
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
//
// Returns true when one of the GPTQ_CALL_IF_MOE arms below was taken and false
// when none of them matched. The macro is defined outside this file; each
// invocation continues the if-chain opened by `if (false) {}`, so it must
// expand to an `else if (...) {...}` arm.
// NOTE(review): the three numeric macro arguments look like
// (thread_n_blocks, thread_k_blocks, num_threads) -- confirm against the macro
// definition.
bool call_marlin_moe_kernel_ku4b8(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input,
    bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) {
  // Not referenced by name in this body; presumably read by the
  // GPTQ_CALL_IF_MOE expansion -- TODO confirm.
  bool has_zp = false;

  // Empty leading arm so that every macro expansion can start with `else if`.
  if (false) {
  }
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
  GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
  else {
    // No instantiation matched; let the caller try the next kernel family.
    return false;
  }
  return true;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
deleted
100644 → 0
View file @
b4c4464d
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// Declaration of the vllm::kU4B8 Marlin MoE kernel dispatcher.
//
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
//
// The argument list is forwarded verbatim by the caller's dispatch chain; the
// pointer arguments are raw device buffers already reinterpreted by the
// caller (int4 for A/B/C, scales and zero points; int for sorted ids, group
// indices, expert offsets and locks; float for top-k weights).
bool call_marlin_moe_kernel_ku4b8(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input,
    bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks);
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
deleted
100644 → 0
View file @
b4c4464d
#include "marlin_moe_kernel_ku8b128.h"
namespace
marlin_moe
{
// Tries to launch a Marlin MoE kernel instantiation for the vllm::kU8B128
// quantized type.
//
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
//
// Returns true when one of the GPTQ_CALL_IF_MOE arms below was taken and false
// when none of them matched. The macro is defined outside this file; each
// invocation continues the if-chain opened by `if (false) {}`, so it must
// expand to an `else if (...) {...}` arm.
// NOTE(review): the three numeric macro arguments look like
// (thread_n_blocks, thread_k_blocks, num_threads) -- confirm against the macro
// definition.
bool call_marlin_moe_kernel_ku8b128(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input,
    bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) {
  // Not referenced by name in this body; presumably read by the
  // GPTQ_CALL_IF_MOE expansion -- TODO confirm.
  bool has_zp = false;

  // Empty leading arm so that every macro expansion can start with `else if`.
  if (false) {
  }
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
  GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
  else {
    // No instantiation matched; let the caller try the next kernel family.
    return false;
  }
  return true;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
deleted
100644 → 0
View file @
b4c4464d
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// Declaration of the vllm::kU8B128 Marlin MoE kernel dispatcher.
//
// Returns bool so that callers can chain several dispatchers as a sequence of
// if / else-if arms and stop at the first one that reports success.
//
// The pointer arguments are raw device buffers already reinterpreted by the
// caller (int4 for A/B/C, scales and zero points; int for sorted ids, group
// indices, expert offsets and locks; float for top-k weights).
bool call_marlin_moe_kernel_ku8b128(
    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
    bool has_act_order, int group_blocks, int num_threads, int blocks,
    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
    int prob_k, int tot_m, int* locks, bool replicate_input,
    bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks);
}
csrc/moe/marlin_moe_ops.cu
deleted
100644 → 0
View file @
b4c4464d
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
// Formats a numeric value as text using std::to_string; used when building
// error messages by string concatenation.
template <typename T>
inline std::string str(T x) {
  std::string text = std::to_string(x);
  return text;
}
namespace
marlin_moe
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Column permutation of a row-major half-precision matrix "a" of size [M,K]:
// for every row, out[row][k] = a[row][perm[k]].
//
// Each thread block handles up to `block_rows` consecutive rows starting at
// block_rows * blockIdx.x; within a row the threads of the block share the K
// columns, blockDim.x columns per pass.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Row range owned by this block, with the end clamped to the matrix height.
  const int first_row = block_rows * blockIdx.x;
  int end_row = first_row + block_rows;
  if (end_row > size_m) {
    end_row = size_m;
  }

  // Bytes per row divided by 16: the row pitch expressed as an offset on the
  // int4 pointers.
  const int int4_per_row = size_k * sizeof(half) / 16;

  // The columns are covered by `full_passes` passes of blockDim.x columns
  // followed by one partial pass of `tail_cols` columns.
  const int full_passes = size_k / blockDim.x;
  const int tail_cols = size_k % blockDim.x;

  for (int row = first_row; row < end_row; ++row) {
    const int row_offset = row * int4_per_row;
    half const* src_row =
        reinterpret_cast<half const*>(a_int4_ptr + row_offset);
    half* dst_row = reinterpret_cast<half*>(out_int4_ptr + row_offset);

    int col_base = 0;
    for (int pass = 0; pass < full_passes; ++pass) {
      const int col = col_base + threadIdx.x;
      const int src_col = perm_int_ptr[col];
      dst_row[col] = src_row[src_col];
      col_base += blockDim.x;
    }

    // Only the first `tail_cols` threads have a column left to copy.
    if (tail_cols != 0 && threadIdx.x < tail_cols) {
      const int col = col_base + threadIdx.x;
      const int src_col = perm_int_ptr[col];
      dst_row[col] = src_row[src_col];
    }
  }
}
// Builds the padded prefix-sum table of tokens per expert.
//
// Launched with one thread per expert (the expert id is threadIdx.x and the
// number of experts is blockDim.x). On return expert_offsets[0] == 0 and
// expert_offsets[e + 1] holds the running total of the per-expert counts for
// experts 0..e, where each count is first rounded up to a multiple of
// block_size. expert_offsets must therefore hold blockDim.x + 1 entries.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  const int my_expert = threadIdx.x;
  const int expert_count = blockDim.x;

  // Every thread counts the entries of topk_ids that select its own expert.
  int hits = 0;
  for (int pos = 0; pos < topk_length; ++pos) {
    if (topk_ids[pos] == my_expert) {
      ++hits;
    }
  }
  expert_offsets[my_expert + 1] = hits;

  // All counts must be written before thread 0 starts accumulating them.
  __syncthreads();

  if (my_expert == 0) {
    expert_offsets[0] = 0;
    int running_total = 0;
    for (int slot = 1; slot <= expert_count; ++slot) {
      running_total += ceildiv(expert_offsets[slot], block_size) * block_size;
      expert_offsets[slot] = running_total;
    }
  }
  __syncthreads();
}
#else
// Fallback definition selected by the `#else` branch, i.e. whenever the
// `__CUDA_ARCH__ >= 800` condition above does not hold. It keeps the symbol
// available but must never actually run.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
// Fallback definition selected by the `#else` branch, i.e. whenever the
// `__CUDA_ARCH__ >= 800` condition above does not hold. It keeps the symbol
// available but must never actually run.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
#endif
// Tile shape and thread count of one kernel configuration. Plain aggregate:
// initialise as thread_config_t{thread_k, thread_n, num_threads}. A value of
// -1 in any field marks the configuration as invalid (see is_valid_config).
struct thread_config_t {
  int thread_k;     // tile extent along K handled by one thread block
  int thread_n;     // tile extent along N handled by one thread block
  int num_threads;  // threads per block
};

// A thread configuration together with the maximum number of M blocks to
// process per kernel invocation. Plain aggregate:
// exec_config_t{max_m_blocks, thread_config_t{...}}.
struct exec_config_t {
  int max_m_blocks;
  thread_config_t tb_cfg;
};
// Candidate thread configurations tried, in order, by
// determine_thread_config() when prob_m <= 16. The first entry that passes
// is_valid_config() wins, so earlier entries have higher priority.
thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
    {64, 64, 128},    // Reduce both 2X
};
// Candidate thread configurations tried, in order, by
// determine_thread_config() when prob_m > 16. The first entry that passes
// is_valid_config() wins, so earlier entries have higher priority.
thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
    {64, 64, 128},    // Reduce N 4X, same K
};
// Returns the size to reserve for quantization scales for one thread block
// using `th_config`.
//
// group_size == -1 means a single scale group per tile, group_size == 0 is
// sized for the worst case of 32-wide groups, and any other value uses
// ceildiv(thread_k, group_size) groups. When act-order is active without a
// full K (has_act_order && !is_k_full) a larger chunk of at least 32 groups is
// sized instead. prob_m, prob_n, prob_k and num_bits are accepted for
// signature symmetry but are not read.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  const int tile_n = th_config.thread_n;
  const int tile_k = th_config.thread_k;

  // Max scale groups touched by one thread-block tile.
  int groups_per_tile = 1;  // group_size == -1
  if (group_size == 0) {
    groups_per_tile = ceildiv(tile_k, 32);  // Worst case is 32 group size
  } else if (group_size != -1) {
    groups_per_tile = ceildiv(tile_k, group_size);
  }

  if (!has_act_order || is_k_full) {
    // Regular case: scales are staged once per pipeline stage.
    int tile_scales = groups_per_tile * tile_n * 2;
    return tile_scales * STAGES;
  }

  // Act-order with partial K: chunk size is 2x pipeline over dim K, and at
  // least 32 scale groups are loaded.
  int chunk_groups = groups_per_tile * STAGES * 2;
  chunk_groups = max(chunk_groups, 32);
  return chunk_groups * tile_n * 4;
}
// Checks whether the A/B tile pipeline of `th_config` fits into shared memory
// next to a scales cache of `scales_cache_size`.
//
// The number of M blocks per tile is max_m_blocks reduced until it does not
// exceed ceildiv(prob_m, 16). Returns true when the pipeline size
// ((A tile + B tile) * STAGES) is below 95% of the shared memory left after
// the scales cache. Raises through TORCH_CHECK when the M-block count would
// reach zero or when the scales cache takes half of shared memory or more.
// prob_n and prob_k are not read.
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  const int values_per_word = 32 / num_bits;

  const int tile_k = th_config.thread_k;
  const int tile_n = th_config.thread_n;

  // B tile size
  const int b_tile_size = (tile_k * tile_n / values_per_word) * 4;

  // A tile size: shrink the M-block count until the problem can fill it.
  const int m_blocks = ceildiv(prob_m, 16);
  while (m_blocks < max_m_blocks) {
    --max_m_blocks;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }
  const int a_tile_size = (16 * max_m_blocks * tile_k) * 2;

  const float pipe_size = (a_tile_size + b_tile_size) * STAGES;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  const float budget = 0.95f * (max_shared_mem - scales_cache_size);
  return pipe_size < budget;
}
// Decides whether `th_config` can be used for the given problem.
//
// A configuration is accepted only if all of the following hold, checked in
// this order:
//   - no field equals the -1 "unset" marker;
//   - prob_k and prob_n are multiples of thread_k and thread_n;
//   - thread_k is 128 or 64;
//   - thread_n / thread_k are not below min_thread_n / min_thread_k;
//   - num_threads is at least 128 (= 4 warps);
//   - the pipeline plus scales cache fits in shared memory
//     (is_valid_cache_size, which may also raise via TORCH_CHECK).
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  const int tile_k = th_config.thread_k;
  const int tile_n = th_config.thread_n;
  const int threads = th_config.num_threads;

  // Sanity: every field must be set.
  const bool fully_specified = tile_k != -1 && tile_n != -1 && threads != -1;
  if (!fully_specified) {
    return false;
  }

  // K/N must be divisible by thread K/N.
  if (!(prob_k % tile_k == 0 && prob_n % tile_n == 0)) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128)
  if (!(tile_k == 128 || tile_k == 64)) {
    return false;
  }

  // Minimum tile extents.
  if (!(tile_n >= min_thread_n && tile_k >= min_thread_k)) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (!(threads >= 128)) {
    return false;
  }

  // Size of the scales cache, then check that the pipeline fits next to it.
  const int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  return is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                             num_bits, scales_cache_size, max_shared_mem);
}
// Picks the first usable execution configuration for the problem.
//
// The candidate table is small_batch_thread_configs when prob_m <= 16 and
// large_batch_thread_configs otherwise. Starting with 4 M blocks per
// invocation and going down to 1 (fewer M blocks reduce cache usage), the
// table is scanned in priority order and the first entry accepted by
// is_valid_config() is returned. If nothing fits, the sentinel
// {0, {-1, -1, -1}} is returned.
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  const bool small_batch = prob_m <= 16;
  thread_config_t const* const candidates =
      small_batch ? small_batch_thread_configs : large_batch_thread_configs;
  const int candidate_count = static_cast<int>(
      small_batch ? sizeof(small_batch_thread_configs) /
                        sizeof(small_batch_thread_configs[0])
                  : sizeof(large_batch_thread_configs) /
                        sizeof(large_batch_thread_configs[0]));

  for (int m_blocks_cap = 4; m_blocks_cap > 0; --m_blocks_cap) {
    for (int idx = 0; idx < candidate_count; ++idx) {
      thread_config_t const& candidate = candidates[idx];
      if (is_valid_config(candidate, m_blocks_cap, prob_m, prob_n, prob_k,
                          num_bits, group_size, has_act_order, is_k_full,
                          max_shared_mem)) {
        return exec_config_t{m_blocks_cap, candidate};
      }
    }
  }

  return exec_config_t{0, {-1, -1, -1}};
}
// Expands to one `else if` arm of a kernel dispatch chain: it calls
// KERNEL_FUNCTION with variables that must already exist under exactly these
// names in the enclosing scope (q_type, thread_n_blocks, ..., max_par and
// exec_cfg) and, if the call returns true, executes an empty body so that the
// remaining arms are skipped. It must be placed directly after an
// `if (...) {}` statement or after another expansion of this macro.
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
// Host-side driver that launches the Marlin MoE kernels for every expert.
//
// Steps, in order:
//   1. Validate the problem shape and query the device (SM count when
//      sms == -1, opt-in shared memory per block).
//   2. Choose the execution config: the caller's (thread_k, thread_n) when
//      both differ from -1, otherwise determine_thread_config(). The result
//      is re-validated with is_valid_config().
//   3. Derive group_blocks from group_size (-1 stays -1, the act-order /
//      partial-K case uses 0, everything else is group_size / 16).
//   4. Launch compute_expert_offsets with one block of num_experts threads.
//   5. For each expert: optionally permute the columns of A into a_tmp
//      (act-order), then walk the M blocks in steps of 4 * max_m_blocks and
//      try the kernel families through CALL_MOE_KERNEL_FUNCTION.
//
// Pointer arguments are opaque device buffers that are reinterpreted here:
//   A, B, C, s, zp, a_tmp          -> int4
//   sorted_ids, topk_ids, g_idx,
//   perm, expert_offsets, workspace -> int
//   topk_weights                    -> float
// B, s, zp, g_idx and perm are advanced by a per-expert stride computed from
// prob_n, prob_k, num_groups and the pack factor (32 / q_type.size_bits()).
//
// `has_zp` is not read by the code visible in this function.
// All work is enqueued on `stream`; failures are reported through TORCH_CHECK.
void marlin_mm_moe(const void* A, const void* B, void* C,
                   const void* sorted_ids, const void* topk_weights,
                   const void* topk_ids, const void* s, void* zp,
                   const void* g_idx, const void* perm, void* a_tmp,
                   void* expert_offsets, int prob_m, int prob_n, int prob_k,
                   void* workspace, vllm::ScalarType const& q_type,
                   bool has_act_order, bool is_k_full, bool has_zp,
                   int num_groups, int group_size, int num_experts, int topk,
                   int moe_block_size, int dev, cudaStream_t stream,
                   int thread_k, int thread_n, int sms, int max_par,
                   bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }
  // `sms` becomes the grid size and a divisor below; a failed query (sms left
  // at -1) or a non-positive caller value must not get that far.
  TORCH_CHECK(sms > 0, "Invalid number of SMs = ", sms);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  int num_bits = q_type.size_bits();

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      // |group_size| < 16 gives group_blocks == 0, which would be used as a
      // divisor on the next line.
      TORCH_CHECK(group_blocks != 0, "group_size = ", group_size,
                  " is too small: group_size / 16 must be non-zero");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }

  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      // Same guard as above: never take a modulus by zero.
      TORCH_CHECK(group_blocks != 0, "group_size = ", group_size,
                  " is too small: group_size / 16 must be non-zero");
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int tot_m = prob_m;

  // One thread per expert fills expert_offsets (num_experts + 1 entries).
  const int* topk_ids_ptr = (const int*)topk_ids;
  int* expert_offsets_ptr = (int*)expert_offsets;
  compute_expert_offsets<<<1, num_experts, 0, stream>>>(
      topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);

  // Remember the caller's act-order flag before it is possibly cleared below:
  // A still has to be permuted even when the non-act-order kernel is used.
  bool do_permute_a = has_act_order;

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  int pack_factor = 32 / q_type.size_bits();

  for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
    const int4* A_ptr = (const int4*)A;
    int4* a_tmp_ptr = (int4*)a_tmp;
    const int4* B_ptr =
        (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
    int4* C_ptr = (int4*)C;
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
    const int4* zp_ptr =
        (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;

    if (do_permute_a) {
      // Permute A columns
      int topk_rows = replicate_input ? tot_m : tot_m * topk;
      int block_rows = ceildiv(topk_rows, blocks);
      permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
          A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
      A_ptr = a_tmp_ptr;
    }

    int tot_m_blocks = ceildiv(tot_m, 16);
    for (int m_block = 0; m_block < tot_m_blocks;
         m_block += 4 * exec_cfg.max_m_blocks) {
      // Try each kernel family in turn; the first one that reports success
      // ends the chain.
      if (false) {
      }
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
      else {
        TORCH_CHECK(false,
                    "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                        str(prob_n) + ", " + str(prob_k) + "]" +
                        ", has_act_order = " + str(has_act_order) +
                        ", num_groups = " + str(num_groups) +
                        ", group_size = " + str(group_size) +
                        ", thread_n_blocks = " + str(thread_n_blocks) +
                        ", thread_k_blocks = " + str(thread_k_blocks));
      }
    }
  }
}
}
// namespace marlin_moe
// Torch entry point for the Marlin MoE GEMM.
//
// Validates the quantization type and the scale / zero-point tensors, derives
// num_groups and group_size, allocates the output and scratch tensors on the
// device of `a`, and forwards everything to marlin_moe::marlin_mm_moe on the
// current CUDA stream of that device.
//
// Checks performed here:
//   - zero points present (b_zeros.size(1) != 0) requires b_q_type == kU4;
//     otherwise b_q_type must be kU4B8 or kU8B128;
//   - b_scales has rank 3, b_scales.size(2) == size_n, and
//     num_groups = b_scales.size(1);
//   - is_k_full == false is only allowed with act-order
//     (g_idx.size(1) != 0);
//   - size_k is divisible by num_groups whenever a group size is derived;
//   - with zero points, b_zeros has rank 3 and dims
//     [*, num_groups, size_n / pack_factor].
//
// group_size ends up as: size_k / num_groups (grouped), 0 (act-order with a
// partial K) or -1 (a single group).
//
// Returns a zero-initialised tensor of shape [size_m, topk, size_n] with the
// dtype of `a`, filled by the kernels.
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    torch::Tensor& b_zeros, const torch::Tensor& g_idx,
    const torch::Tensor& perm, torch::Tensor& workspace,
    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  bool has_zp = b_zeros.size(1) != 0;
  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
  }

  int pack_factor = 32 / b_q_type.size_bits();

  int max_par = 4;

  int dev = a.get_device();

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
  // Scratch buffer for the column-permuted activations (act-order only).
  torch::Tensor a_tmp =
      replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                      : torch::zeros({size_m, topk, size_k}, options_dtype);
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(1) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
              "if is_k_full is false, has_act_order must be true");

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }

  } else {
    if (num_groups > 1) {
      // The divisor is num_groups, which was read from b_scales.size(1); the
      // message reports that same dimension.
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(1) = ", b_scales.size(1));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    TORCH_CHECK(b_zeros.size(1) == num_groups,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                "b_zeros dim 2 = ", b_zeros.size(2),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
      num_experts, topk, moe_block_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
      replicate_input, apply_weights);
  return c;
}
// Registers marlin_gemm_moe as the CUDA implementation of the
// "marlin_gemm_moe" operator in the extension's library namespace
// (TORCH_EXTENSION_NAME). The operator schema itself is not defined here.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}
csrc/moe/marlin_moe_wna16/.gitignore
0 → 100644
View file @
4c676e3d
kernel_*.cu
\ No newline at end of file
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
4c676e3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
glob
import
itertools
import
os
...
...
@@ -25,15 +26,16 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
]
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
...
...
@@ -41,7 +43,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
2
,
4
,
8
]
GROUP_BLOCKS
=
[
0
,
-
1
,
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
...
...
@@ -52,21 +54,35 @@ def remove_old_kernels():
def
generate_new_kernels
():
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
has_zp
=
"B"
not
in
scalar_type
all_template_str_list
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
has_act_order
=
group_blocks
==
0
if
has_zp
and
has_act_order
:
# act order case only support gptq-int4 and gptq-int8
if
group_blocks
==
0
and
scalar_type
not
in
[
"vllm::kU4B8"
,
"vllm::kU8B128"
]:
continue
if
thread_configs
[
2
]
==
256
:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if
m_blocks
<=
1
and
thread_configs
[
0
]
!=
128
:
continue
if
m_blocks
>
1
and
thread_configs
[
0
]
!=
64
:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
continue
# nvfp4 only supports group_size == 16
if
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
not
in
[
1
,
2
]:
continue
# other quantization methods don't support group_size = 16
if
scalar_type
!=
"vllm::kFE2M1f"
and
group_blocks
==
1
:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
threads
=
thread_configs
[
2
]
...
...
@@ -82,8 +98,6 @@ def generate_new_kernels():
thread_k_blocks
=
k_blocks
,
m_block_size_8
=
m_blocks
==
0.5
,
stages
=
"pipe_stages"
,
has_act_order
=
has_act_order
,
has_zp
=
has_zp
,
group_blocks
=
group_blocks
,
is_zp_float
=
False
,
)
...
...
csrc/moe/marlin_moe_wna16/kernel.h
View file @
4c676e3d
...
...
@@ -7,18 +7,19 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
...
...
@@ -33,11 +34,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
MARLIN_KERNEL_PARAMS
);
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
4c676e3d
...
...
@@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -77,8 +76,8 @@ __global__ void Marlin(
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{}
}
// namespace MARLIN_NAMESPACE_NAME
...
...
@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
//
// `lut` is the compile-time truth table passed as the immediate operand of the
// PTX `lop3.b32` instruction; `a`, `b` and `c` are its three inputs and the
// combined 32-bit result is returned.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}
// Byte permutation issued as a PTX `prmt.b32` instruction: the destination
// register is assembled from bytes of two sources, selected by `mask`. The
// first source is the runtime value `a`; the second is the compile-time
// immediate `start_byte`.
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t picked;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(picked)
               : "r"(a), "n"(start_byte), "n"(mask));
  return picked;
}
// Primary template declaration for weight dequantization. Only explicit
// specializations are defined; no generic definition accompanies this
// declaration here.
//   scalar_t - compute dtype used to select ScalarType<scalar_t>::FragB
//   bit      - bit width of each packed quantized value
//   q        - 32-bit word holding the packed quantized values
//   frag_b   - fragment that the specializations fill in and also return
template
<
typename
scalar_t
,
int
bit
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant
(
int
q
,
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
);
//
// Unpack 4-bit values packed in an int32 into a full B-fragment of 4 fp16
// values. The approach mostly follows the references below, with some small
// changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
    int q, typename ScalarType<half>::FragB& frag_b) {
  // Immediate for `(a & b) | c`; spelling it out guarantees a single LOP3.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  const int low_nibbles = 0x000f000f;
  const int high_nibbles = 0x00f000f0;
  // fp16 bit pattern of 1024 in each half; OR-ing a nibble into the mantissa
  // yields 1024 + n (low nibbles) or 1024 + 16 * n (high nibbles).
  const int exponent_bits = 0x64006400;

  int lo = lop3<and_or_lut>(q, low_nibbles, exponent_bits);
  int hi = lop3<and_or_lut>(q, high_nibbles, exponent_bits);

  // Signed int4 outputs are wanted, so the `-8` symmetric zero point is folded
  // straight into the subtrahend and the addend.
  const int sub_bits = 0x64086408;  // fp16 1032 = 1024 + 8
  const int mul_bits = 0x2c002c00;  // fp16 1/16
  const int add_bits = 0xd480d480;  // fp16 -72 = -(1024 / 16 + 8)

  const half2 subtrahend = *reinterpret_cast<const half2*>(&sub_bits);
  const half2 multiplier = *reinterpret_cast<const half2*>(&mul_bits);
  const half2 addend = *reinterpret_cast<const half2*>(&add_bits);

  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), subtrahend);
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi), multiplier, addend);
  return frag_b;
}
// bf16 counterpart of the 4-bit dequantizer: unpacks nibbles of `q` into a
// B-fragment of 4 bf16 values. Each nibble is OR-ed into the mantissa of the
// bf16 constant 128, then an FMA with (1.0, -136) produces n - 8.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
                        typename ScalarType<nv_bfloat16>::FragB& frag_b) {
  // Immediate for `(a & b) | c`; spelling it out guarantees a single LOP3.
  constexpr int and_or_lut = (0xf0 & 0xcc) | 0xaa;

  static constexpr uint32_t nibble_mask = 0x000f000f;
  static constexpr uint32_t exponent_bits = 0x43004300;  // bf16 128

  // bf16 has too few mantissa bits to hold the high nibble in place, so the
  // word is shifted right by 4 and the same low-nibble mask is reused.
  int lo = lop3<and_or_lut>(q, nibble_mask, exponent_bits);
  int hi = lop3<and_or_lut>(q >> 4, nibble_mask, exponent_bits);

  static constexpr uint32_t mul_bits = 0x3F803F80;  // bf16 1.0
  static constexpr uint32_t add_bits = 0xC308C308;  // bf16 -136 = -(128 + 8)

  const nv_bfloat162 multiplier =
      *reinterpret_cast<const nv_bfloat162*>(&mul_bits);
  const nv_bfloat162 addend =
      *reinterpret_cast<const nv_bfloat162*>(&add_bits);

  frag_b[0] =
      __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo), multiplier, addend);
  frag_b[1] =
      __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi), multiplier, addend);
  return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: dequantize 8-bit integer values to fp16 or bf16.
// Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
    int q, typename ScalarType<half>::FragB& frag_b) {
  // prmt selectors: each pairs bytes of `q` with 0x64 bytes from the immediate
  // so that every 16-bit half becomes an fp16 bit pattern 0x64xx.
  static constexpr uint32_t select_elt_01 = 0x5250;
  static constexpr uint32_t select_elt_23 = 0x5351;
  static constexpr uint32_t fp16_high_bytes = 0x64646464;

  uint32_t lo = prmt<fp16_high_bytes, select_elt_01>(q);
  uint32_t hi = prmt<fp16_high_bytes, select_elt_23>(q);

  // fp16 1152 = 1024 + 128 in both halves; subtracting it removes the 1024
  // bias introduced above together with a further 128 offset.
  static constexpr uint32_t bias_bits = 0x64806480;
  const half2 bias = *reinterpret_cast<const half2*>(&bias_bits);

  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), bias);
  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi), bias);
  return frag_b;
}
// bf16 variant of the 8-bit dequantizer. Each byte of `q` is spliced into the
// low byte of the fp32 constant 2^23, the bias (2^23 + 128) is subtracted in
// fp32, and the upper 16 bits of each fp32 result are packed as bf16 values.
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
                        typename ScalarType<nv_bfloat16>::FragB& frag_b) {
  float widened[4];
  uint32_t* widened_bits = reinterpret_cast<uint32_t*>(widened);

  static constexpr uint32_t fp32_base = 0x4B000000;  // fp32 2^23

  // Keep the order 0, 2, 1, 3 of source bytes: it determines which values end
  // up paired in each output word.
  widened_bits[0] = __byte_perm(q, fp32_base, 0x7650);
  widened_bits[1] = __byte_perm(q, fp32_base, 0x7652);
  widened_bits[2] = __byte_perm(q, fp32_base, 0x7651);
  widened_bits[3] = __byte_perm(q, fp32_base, 0x7653);

  // 8388736 = 2^23 + 128.
#pragma unroll
  for (int i = 0; i < 4; i++) {
    widened[i] -= 8388736.f;
  }

  // Selector 0x7632 takes the two high bytes of each operand, i.e. truncates
  // both fp32 values to bf16 and packs them into one 32-bit word.
  uint32_t* packed_out = reinterpret_cast<uint32_t*>(&frag_b);
  packed_out[0] = __byte_perm(widened_bits[0], widened_bits[1], 0x7632);
  packed_out[1] = __byte_perm(widened_bits[2], widened_bits[3], 0x7632);
  return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
...
...
@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -442,9 +301,11 @@ __global__ void Marlin(
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
uint16_t
*
__restrict__
scale2_ptr
,
// fp16 global scale (for nvfp4
// only)
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int32_t
*
__restrict__
sorted_token_ids_ptr
,
// moe sorted_ids
const
int32_t
*
__restrict__
expert_ids_ptr
,
// moe expert ids
const
int32_t
*
__restrict__
num_tokens_past_padded_ptr
,
// moe num tokens
...
...
@@ -458,8 +319,8 @@ __global__ void Marlin(
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
...
...
@@ -481,13 +342,26 @@ __global__ void Marlin(
extern
__shared__
int4
sh
[];
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
bool
has_zp
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
;
constexpr
bool
is_int_type
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
||
w_type
==
vllm
::
kU4B8
||
w_type
==
vllm
::
kU8B128
;
// see comments of dequant.h for more details
constexpr
bool
dequant_skip_flop
=
!
is_int_type
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
w_type
==
vllm
::
kU8
);
scalar_t2
global_scale
;
constexpr
bool
has_act_order
=
group_blocks
==
0
;
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
static_assert
(
thread_m_blocks
==
1
||
!
m_block_size_8
);
constexpr
int
moe_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
const
int
group_size
=
(
!
has_act_order
&&
group_blocks
==
-
1
)
?
prob_k
:
prob_k
/
num_groups
;
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
8
;
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
(
w_type
==
vllm
::
kFE2M1f
?
16
:
8
);
const
int
zp_expert_stride
=
is_zp_float
?
prob_n
*
prob_k
/
group_size
/
8
:
prob_n
*
prob_k
/
group_size
/
(
pack_factor
*
4
);
...
...
@@ -534,13 +408,20 @@ __global__ void Marlin(
int64_t
B_expert_off
=
0
;
int4
*
sh_block_sorted_ids_int4
=
sh
;
int4
*
sh_rd_block_sorted_ids_int4
=
sh_block_sorted_ids_int4
+
moe_block_size
/
4
;
int4
*
sh_block_topk_weights_int4
=
sh_rd_block_sorted_ids_int4
+
moe_block_size
/
4
;
// sh_block_topk_weights_int4 only needs (moe_block_size / 4) int4 slots,
// but we pad it so that sh_new stays aligned to 256 bytes
int4
*
sh_new
=
sh_block_topk_weights_int4
+
moe_block_size
/
2
+
moe_block_size
;
int32_t
*
sh_block_sorted_ids
=
reinterpret_cast
<
int
*>
(
sh_block_sorted_ids_int4
);
int
4
*
sh_block_
topk_weights_int4
=
sh_block_sorted_ids_int4
+
moe_block_size
/
4
;
int
32_t
*
sh_
rd_
block_
sorted_ids
=
reinterpret_cast
<
int
*>
(
sh_rd_block_sorted_ids_int4
)
;
scalar_t2
*
sh_block_topk_weights
=
reinterpret_cast
<
scalar_t2
*>
(
sh_block_topk_weights_int4
);
int4
*
sh_new
=
sh_block_topk_weights_int4
+
moe_block_size
/
4
;
int32_t
block_num_valid_tokens
=
0
;
int32_t
locks_off
=
0
;
...
...
@@ -584,12 +465,24 @@ __global__ void Marlin(
sh_block_sorted_ids_int4
[
tid4
]
=
reinterpret_cast
<
const
int4
*>
(
sorted_token_ids_ptr
)[
block_id
*
moe_block_size
/
4
+
tid4
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
sh_rd_block_sorted_ids
[
tid4
*
4
+
i
]
=
sh_block_sorted_ids
[
tid4
*
4
+
i
]
/
top_k
;
if
(
mul_topk_weights
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]]));
int
idx
=
tid4
*
4
+
i
;
idx
=
idx
<
block_num_valid_tokens
?
idx
:
0
;
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
sh_block_topk_weights
[
idx
]
=
__hmul2
(
global_scale
,
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
idx
]])));
}
else
{
sh_block_topk_weights
[
idx
]
=
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
idx
]]));
}
}
}
}
...
...
@@ -620,6 +513,11 @@ __global__ void Marlin(
expert_id
=
expert_ids_ptr
[
block_id
];
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
uint16_t
val
=
scale2_ptr
[
expert_id
];
global_scale
=
Dtype
::
num2num2
(
*
reinterpret_cast
<
scalar_t
*>
(
&
val
));
}
B_expert_off
=
expert_id
*
prob_n
*
prob_k
/
(
pack_factor
*
4
);
scales_ptr
+=
(
expert_id
-
old_expert_id
)
*
scales_expert_stride
;
if
constexpr
(
has_zp
)
{
...
...
@@ -733,7 +631,7 @@ __global__ void Marlin(
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
...
...
@@ -743,6 +641,7 @@ __global__ void Marlin(
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr
int
act_s_max_num_groups
=
32
;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
...
...
@@ -758,9 +657,9 @@ __global__ void Marlin(
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
)
;
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
int
a_gl_rd
_row
=
threadIdx
.
x
/
a_gl_rd_delta_o
;
int
a_gl_rd_col
=
a_gl_rd_delta_o
*
slice_row
+
threadIdx
.
x
%
a_gl_rd_delta_o
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
...
...
@@ -774,8 +673,8 @@ __global__ void Marlin(
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
...
...
@@ -790,11 +689,12 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
auto
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
...
...
@@ -807,17 +707,27 @@ __global__ void Marlin(
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
auto
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
if
constexpr
(
group_blocks
!=
-
1
&&
w_type
==
vllm
::
kFE2M1f
)
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
has_zp
))
s_sh_rd
=
s_sh_rd
*
2
+
warp_row
%
2
;
}
else
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
(
has_zp
&&
!
dequant_skip_flop
)))
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
8
;
else
...
...
@@ -851,7 +761,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
(
row
%
8
)
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
...
...
@@ -879,12 +789,28 @@ __global__ void Marlin(
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh_new
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
constexpr
int
sh_red_size
=
(
2
*
thread_n_blocks
+
1
)
*
16
*
thread_m_blocks
;
constexpr
int
sh_b_size
=
stages
*
b_sh_stage
;
int4
*
sh_b
=
sh_new
;
int4
*
sh_red
=
sh_new
;
int4
*
sh_g_idx
=
sh_b
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
constexpr
int
sh_s_size
=
has_act_order
?
(
act_s_max_num_groups
*
s_sh_stride
)
:
(
stages
*
s_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_b
;
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert
(
thread_m_blocks
*
16
*
thread_n_blocks
*
16
/
8
<=
stages
*
b_sh_stage
);
int4
*
sh_a
=
sh_s
+
sh_s_size
;
constexpr
int
shm_size_used
=
moe_block_size
+
stages
*
(
g_idx_stage
+
zp_sh_stage
)
+
sh_s_size
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int
sh_a_max_row
=
((
max_shared_mem
-
1024
)
/
16
-
shm_size_used
)
/
(
thread_k_blocks
*
2
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
...
...
@@ -905,15 +831,14 @@ __global__ void Marlin(
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_act_order_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh
_max_num_groups
)
{
sh_num_groups
=
s
h
_max_num_groups
;
if
(
sh_num_groups
>
act_s
_max_num_groups
)
{
sh_num_groups
=
act_
s_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
...
...
@@ -940,27 +865,31 @@ __global__ void Marlin(
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
int
a_remaining_load_count_in_slice
=
stages
;
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
bool
should_load_a
=
true
;
int
max_num_stage_groups
=
((
sh_a_max_row
-
moe_block_size
)
/
moe_block_size
+
1
)
/
stages
;
max_num_stage_groups
=
max
(
max_num_stage_groups
,
1
);
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
,
int
pipe_a
=
0
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
if
(
prob_k
>
thread_k_blocks
*
16
*
stages
||
slice_col
==
0
||
a_remaining_load_count_in_slice
>
0
)
{
a_remaining_load_count_in_slice
--
;
if
(
should_load_a
)
{
int4
*
sh_a_stage
=
sh_a
+
moe_block_size
*
a_sh_stride
*
pipe_a
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
;
int
row
=
a_idx
/
a_gl_stride
;
int
row
=
a_gl_rd_delta_i
/
a_gl_stride
*
i
+
a_gl_rd_row
;
int64_t
sorted_row
=
0
;
if
(
!
m_block_size_8
||
row
<
8
)
sorted_row
=
sh_block_sorted_ids
[
row
]
/
top_k
;
int64_t
true_idx
=
sorted_row
*
a_gl_stride
+
a_idx
%
a_gl_stride
;
sorted_row
=
sh_rd_block_sorted_ids
[
row
];
int64_t
true_idx
=
sorted_row
*
a_gl_stride
+
a_gl_rd_col
+
a_gl_rd_delta_o
*
a_off
;
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
true_idx
],
row
<
block_num_valid_tokens
);
}
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
...
...
@@ -1063,8 +992,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_st
ag
e
*
pipe
;
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
,
int
pipe_a
=
0
)
{
int4
*
sh_a_stage
=
sh_a
+
moe_block_size
*
a_sh_st
rid
e
*
pipe
_a
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm
<
m_block_size_8
?
2
:
4
,
scalar_t
>
(
...
...
@@ -1109,12 +1038,17 @@ __global__ void Marlin(
}
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1123,12 +1057,19 @@ __global__ void Marlin(
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
cur_group_id
=
k_blocks
/
(
group_blocks
*
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
));
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
if
constexpr
(
w_type_id
!=
vllm
::
kFE2M1f
.
id
())
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
else
{
reinterpret_cast
<
int2
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
reinterpret_cast
<
int2
*>
(
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
}
}
}
...
...
@@ -1152,7 +1093,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
...
...
@@ -1161,7 +1102,7 @@ __global__ void Marlin(
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
auto
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
...
...
@@ -1222,15 +1163,18 @@ __global__ void Marlin(
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1251,6 +1195,7 @@ __global__ void Marlin(
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
...
...
@@ -1263,12 +1208,16 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1292,6 +1241,10 @@ __global__ void Marlin(
}
};
auto
dequant_data
=
[
&
](
int
q
,
scalar_t2
*
frag_b_ptr
)
{
dequant
<
scalar_t2
,
w_type_id
,
dequant_skip_flop
>
(
q
,
frag_b_ptr
);
};
// Execute the actual tensor core matmul of a sub-tile.
bool
is_first_matmul_in_slice
=
true
;
auto
matmul
=
[
&
](
int
k
)
{
...
...
@@ -1315,15 +1268,27 @@ __global__ void Marlin(
zp_quant_1
=
frag_qzp
[
k2
][
1
];
}
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
zp_quant_0
,
frag_zp_0
);
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
zp_quant_1
,
frag_zp_1
);
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
dequant_data
(
zp_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
));
dequant_data
(
zp_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
)
+
2
);
}
}
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
is_zp_float
)
{
if
(
is_new_zp
)
{
reinterpret_cast
<
int4
*>
(
&
frag_zp
)[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k2
])[
0
];
}
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
]));
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
])
+
2
);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
...
...
@@ -1332,7 +1297,10 @@ __global__ void Marlin(
FragB
frag_b1
;
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
if
constexpr
(
w_type_id
==
vllm
::
kFE2M1f
.
id
())
{
b_quant_1
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_0
=
b_quant_1
<<
8
;
}
else
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
...
...
@@ -1342,8 +1310,13 @@ __global__ void Marlin(
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
b_quant_0
,
frag_b0
);
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
b_quant_1
,
frag_b1
);
dequant_data
(
b_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b0
));
dequant_data
(
b_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b1
));
if
constexpr
(
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
...
...
@@ -1351,9 +1324,9 @@ __global__ void Marlin(
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
0
);
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
act_frag_s
[
k
2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
s2
=
Dtype
::
nums2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
[
j
/
2
][
j
%
2
*
2
+
0
])[
idx
],
...
...
@@ -1361,18 +1334,12 @@ __global__ void Marlin(
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
s2
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
s2
.
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
s2
.
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_z
p
&&
!
i
s_zp
_float
&&
group_blocks
!=
-
1
)
{
}
else
if
constexpr
(
!
dequant_skip_flo
p
&&
ha
s_zp
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k
%
2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k
%
2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zpf
[
k2
][
j
]
=
__hmul2
(
frag_zpf
[
k2
][
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
].
x
,
frag_zpf
[
k2
][
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
].
y
,
frag_zpf
[
k2
][
j
].
y
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
],
0
);
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
],
1
);
...
...
@@ -1397,7 +1364,7 @@ __global__ void Marlin(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
...
...
@@ -1634,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
4
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
if
(
!
mul_topk_weights
)
{
res
=
__hmul2
(
res
,
global_scale
);
}
}
if
constexpr
(
m_block_size_8
)
{
((
scalar_t
*
)
sh_red
)[
idx
]
=
res
.
x
;
((
scalar_t
*
)
sh_red
)[
idx
+
8
*
c_sh_stride
]
=
res
.
y
;
...
...
@@ -1728,10 +1702,12 @@ __global__ void Marlin(
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_col_zp_to_shared
();
fetch_col_scale_to_shared
();
if
constexpr
(
!
dequant_skip_flop
)
{
fetch_col_scale_to_shared
();
}
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
,
i
);
}
zero_accums
();
...
...
@@ -1740,8 +1716,10 @@ __global__ void Marlin(
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
a_gl_rd_col
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
if
constexpr
(
has_act_order
)
{
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
}
};
if
(
slice_iters
)
{
start_pipes
();
...
...
@@ -1754,43 +1732,59 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at
// index 0.
for
(
int
stage_group_id
=
0
;
stage_group_id
<
max_num_stage_groups
;
stage_group_id
++
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
int
idx
=
(
pipe
>=
stages
&&
stage_group_id
==
max_num_stage_groups
-
1
)
?
(
pipe
-
stages
)
:
(
pipe
+
stage_group_id
*
stages
);
fetch_to_registers
(
k
+
1
,
pipe
%
stages
,
idx
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
int
idx
=
(
pipe
>=
1
&&
stage_group_id
==
max_num_stage_groups
-
1
)
?
(
pipe
-
1
)
:
(
pipe
+
(
stage_group_id
+
1
)
*
stages
-
1
);
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
,
idx
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_remaining_load_count_in_slice
=
0
;
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
a_gl_rd_col
+=
a_gl_rd_delta_o
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
if
constexpr
(
has_act_order
)
{
slice_k_start
+=
tb_k
*
stages
;
if
(
slice_k_start
<
prob_k
)
{
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
if
(
slice_iters
==
0
)
{
break
;
}
}
...
...
@@ -1802,7 +1796,8 @@ __global__ void Marlin(
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
...
...
@@ -1812,7 +1807,8 @@ __global__ void Marlin(
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
cp_async_wait
<
0
>
();
__syncthreads
();
...
...
@@ -1836,7 +1832,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
8
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
...
...
@@ -1877,15 +1874,30 @@ __global__ void Marlin(
if
(
last
||
use_atomic_add
)
// only the last block in a slice actually writes the result
write_result
();
i
f
(
slice_row
)
a_remaining_load_count_in_slice
=
stages
;
i
nt
old_
slice_row
=
slice_row
;
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
is_first_matmul_in_slice
=
true
;
init_slice
();
// Should we load the A matrix in the next slice?
// `slice_col == 0`: when moving to a new MoE block
// `old_slice_row > 0`:
// when the previous slice did not start from k_index == 0
// (only happens when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if
(
slice_col
==
0
||
old_slice_row
||
prob_k
>
thread_k_blocks
*
16
*
stages
*
max_num_stage_groups
)
{
should_load_a
=
true
;
}
else
{
should_load_a
=
false
;
}
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd_col
=
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
...
...
@@ -1900,12 +1912,10 @@ __global__ void Marlin(
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
4c676e3d
...
...
@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups
*
pipe_stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
...
...
@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
bool
m_block_size_8
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
tb_m
=
thread_m_blocks
*
16
;
int
tb_m
=
thread_m_blocks
*
(
m_block_size_8
?
8
:
16
)
;
// shm size for block_sorted_ids/block_topk_weights
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// each of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int
sh_block_meta_size
=
tb_m
*
4
*
2
;
int
sh_block_meta_size
=
tb_m
*
4
;
int
sh_a_size
=
pipe_stages
*
(
tb_m
*
tb_k
)
*
2
;
int
sh_b_size
=
pipe_stages
*
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
int
sh_red_size
=
tb_m
*
(
tb_n
+
8
)
*
2
;
int
sh_s_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
...
...
@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size
=
sh_s_size
/
2
;
}
int
total_size
=
sh_
a
_size
+
sh_
b
_size
+
sh_
s
_size
+
sh_
zp
_size
+
sh_g_idx_size
+
sh_block_meta_size
;
int
total_size
=
max
(
sh_
b
_size
,
sh_
red
_size
)
+
sh_
a
_size
+
sh_
s
_size
+
sh_zp_size
+
sh_g_idx_size
+
sh_block_meta_size
;
return
total_size
;
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
bool
m_block_size_8
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
...
...
@@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
return
cache_size
<=
max_shared_mem
;
}
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// these are the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template
<
typename
scalar_t
>
MarlinFuncPtr
get_marlin_kernel
(
const
vllm
::
ScalarType
q_type
,
...
...
@@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto
kernel
=
MarlinDefault
;
if
(
false
)
{
}
GPTQ_GET_IF_M1
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_GET_IF_M1
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_GET_IF_M234
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_GET_IF_M234
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_GET_IF_M1
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_GET_IF_M1
(
vllm
::
kU8B128
,
8
,
4
,
128
)
COMMON_GET_IF
(
vllm
::
kU4
)
COMMON_GET_IF
(
vllm
::
kU4B8
)
COMMON_GET_IF
(
vllm
::
kU8B128
)
GPTQ_GET_IF_M234
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_GET_IF_M234
(
vllm
::
kU8B128
,
8
,
4
,
128
)
BIGGROUP_GET_IF
(
vllm
::
kFE4M3fn
)
AWQ_GET_IF_M1
(
vllm
::
kU4
,
8
,
8
,
256
)
AWQ_GET_IF_M1
(
vllm
::
kU4
,
8
,
4
,
128
)
FP4_GET_IF
(
vllm
::
kFE2M1f
)
A
WQ
_GET_IF
_M234
(
vllm
::
kU4
,
16
,
4
,
256
)
A
WQ
_GET_IF
_M234
(
vllm
::
kU
4
,
8
,
4
,
128
)
A
CT
_GET_IF
(
vllm
::
kU4
B8
)
A
CT
_GET_IF
(
vllm
::
kU
8B
128
)
return
kernel
;
}
...
...
@@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for
(
int
i
=
0
;
i
<
thread_configs_size
;
i
++
)
{
thread_config_t
th_config
=
thread_configs
[
i
];
if
(
!
is_valid_config
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
if
(
!
is_valid_config
(
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
continue
;
}
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
int
group_blocks
=
0
;
if
(
!
has_act_order
)
{
group_blocks
=
group_size
==
-
1
?
-
1
:
group_size
/
16
;
group_blocks
=
group_size
==
-
1
?
-
1
:
(
group_size
/
16
)
;
}
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
...
...
@@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
void
*
s2
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
void
*
sorted_token_ids
,
void
*
expert_ids
,
void
*
num_tokens_past_padded
,
void
*
topk_weights
,
int
moe_block_size
,
int
top_k
,
bool
mul_topk_weights
,
bool
is_ep
,
...
...
@@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
||
q_type
==
vllm
::
kFE2M1f
,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = "
,
q_type
.
str
());
}
...
...
@@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
uint16_t
*
s2_ptr
=
(
const
uint16_t
*
)
s2
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
...
...
@@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
)
,
"Invalid thread config: thread_m_blocks = "
,
th
re
a
d_m
_blocks
,
", thread_k
= "
,
thread_
tfg
.
thread_k
,
", thread_
n
= "
,
thread_tfg
.
thread_
n
,
",
num_
thread
s
= "
,
thread_tfg
.
num_
thread
s
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
",
"
,
prob_
n
,
"
] and num_bits = "
,
num_bits
,
",
g
ro
up_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
",
has_zp = "
,
has_zp
,
",
i
s_zp
_float
= "
,
i
s_zp
_float
,
", max_shared_mem = "
,
max_shared_mem
);
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_sha
red_m
em
)
,
"Invalid thread config: thread_m_blocks
= "
,
thread_
m_blocks
,
", thread_
k
= "
,
thread_tfg
.
thread_
k
,
", thread
_n
= "
,
thread_tfg
.
thread
_n
,
", num_threads = "
,
thread_tfg
.
num_threads
,
" for MKN = [
"
,
prob_
m
,
"
, "
,
prob_k
,
",
"
,
p
ro
b_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
",
is_k_full = "
,
is_k_full
,
",
ha
s_zp = "
,
ha
s_zp
,
", is_zp_float = "
,
is_zp_float
,
", max_shared_mem = "
,
max_shared_mem
);
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
q_type
,
thread_m_blocks
,
thread_n_blocks
,
thread_k_blocks
,
m_block_size_8
,
...
...
@@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel
<<<
blocks
,
num_threads
,
max_shared_mem
,
stream
>>>
(
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
zp_ptr
,
g_idx_ptr
,
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
s2_ptr
,
zp_ptr
,
g_idx_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_past_padded_ptr
,
topk_weights_ptr
,
top_k
,
mul_topk_weights
,
is_ep
,
num_groups
,
prob_m
,
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
);
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
,
max_shared_mem
);
// clang-format on
}
...
...
@@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch
::
Tensor
moe_wna16_marlin_gemm
(
torch
::
Tensor
&
a
,
std
::
optional
<
torch
::
Tensor
>
const
&
c_or_none
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
global_scale_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
g_idx_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
...
...
@@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
}
}
torch
::
Tensor
global_scale
;
if
(
global_scale_or_none
.
has_value
())
{
global_scale
=
global_scale_or_none
.
value
();
TORCH_CHECK
(
b_q_type
==
vllm
::
kFE2M1f
,
"global_scale can only be used for float4_e2m1f."
);
}
else
{
global_scale
=
torch
::
empty
({
0
},
options
);
TORCH_CHECK
(
!
(
b_q_type
==
vllm
::
kFE2M1f
),
"the global_scale parameter must be passed for float4_e2m1f."
);
}
torch
::
Tensor
b_zeros
;
if
(
b_zeros_or_none
.
has_value
())
{
b_zeros
=
b_zeros_or_none
.
value
();
...
...
@@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm(
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
b_q_type
==
vllm
::
kU4
||
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4
or u8
when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
.
str
());
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
||
b_q_type
==
vllm
::
kFE4M3fn
||
b_q_type
==
vllm
::
kFE2M1f
,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = "
,
b_q_type
.
str
());
}
if
(
has_zp
&&
is_zp_float
)
{
...
...
@@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Half
>
();
}
MARLIN_NAMESPACE_NAME
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
...
...
@@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
BFloat16
>
();
}
MARLIN_NAMESPACE_NAME
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
topk_weights
.
data_ptr
(),
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
size_m
,
size_n
,
size_k
,
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
4c676e3d
...
...
@@ -399,7 +399,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
if
(
use_global_memory
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
...
...
@@ -424,7 +424,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer
.
data_ptr
<
int32_t
>
());
});
}
else
if
(
use_i16
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// set dynamic shared mem
auto
kernel
=
...
...
@@ -439,7 +439,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids
.
numel
());
});
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
,
int32_t
>
;
...
...
@@ -464,7 +464,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK
(
num_experts
==
256
,
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `cumsum` tensors
auto
options_int
=
...
...
csrc/moe/moe_ops.h
View file @
4c676e3d
...
...
@@ -30,6 +30,12 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
);
#endif
bool
moe_permute_unpermute_supported
();
void
shuffle_rows
(
const
torch
::
Tensor
&
input_tensor
,
const
torch
::
Tensor
&
dst2src_map
,
torch
::
Tensor
&
output_tensor
);
std
::
vector
<
torch
::
Tensor
>
moe_fused_gate
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
bias
,
...
...
@@ -37,4 +43,4 @@ std::vector<torch::Tensor> moe_fused_gate(
int64_t
topk_group
,
int64_t
topk
,
int64_t
n_share_experts_fusion
,
double
routed_scaling_factor
);
\ No newline at end of file
double
routed_scaling_factor
);
csrc/moe/moe_permute_unpermute_op.cu
0 → 100644
View file @
4c676e3d
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
/// Regroups the expanded (token, top-k slot) rows of `input` so that rows
/// routed to the same expert are contiguous, and records the bookkeeping that
/// moe_unpermute needs to undo the permutation.
///
/// All work is enqueued on the current CUDA stream, in this order:
///   1. If `expert_map` is given, `topk_ids` is rewritten IN PLACE: ids owned
///      by this rank become local ids, every other id is shifted up by
///      `n_expert` so it sorts after all local experts.
///   2. `topk_ids` is radix-sorted with `token_expert_indicies` as payload and
///      `expert_first_token_offset` is filled from the sorted ids.
///   3. Rows of `input` are copied into `permuted_input` and
///      `src_row_id2dst_row_id_map` is written.
///   4. `m_indices` receives the expert id of every permuted row; when
///      `align_block_size` is set, `expert_first_token_offset` is overwritten
///      with the block-aligned offsets.
///
/// Raises (TORCH_CHECK) when a dtype or shape precondition listed below fails.
void moe_permute(
    const torch::Tensor& input,                      // [n_token, hidden]
    const torch::Tensor& topk_weights,               // [n_token, topk]
    torch::Tensor& topk_ids,                         // [n_token, topk]
    const torch::Tensor& token_expert_indicies,      // [n_token, topk]
    const std::optional<torch::Tensor>& expert_map,  // [n_expert]
    int64_t n_expert, int64_t n_local_expert, int64_t topk,
    const std::optional<int64_t>& align_block_size,
    torch::Tensor&
        permuted_input,  // [topk * n_token/align_block_size_m, hidden]
    torch::Tensor& expert_first_token_offset,  // [n_local_expert + 1]
    torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
    torch::Tensor& m_indices) {                // [align_expand_m]
  // The launchers below reinterpret the raw storage, so every dtype has to be
  // validated up front.
  TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
              "topk_weights must be float32");
  TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
              "expert_first_token_offset must be int64");
  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
              "topk_ids must be int32");
  TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
              "token_expert_indicies must be int32");
  TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
              "src_row_id2dst_row_id_map must be int32");
  TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
              "expert_first_token_offset shape != n_local_expert+1");
  TORCH_CHECK(
      src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
      "token_expert_indicies shape must be same as src_row_id2dst_row_id_map");

  const auto n_token = input.sizes()[0];
  const auto n_hidden = input.sizes()[1];
  // -1 is the "no alignment" sentinel understood by the launchers.
  const auto align_block_size_value =
      align_block_size.has_value() ? align_block_size.value() : -1;
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  // Scratch buffer for the radix sort. It is allocated on the same device as
  // the keys being sorted (topk_ids) and handed over as an untyped pointer.
  const int64_t sorter_size = static_cast<int64_t>(
      CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert));
  auto sort_workspace =
      torch::empty({sorter_size}, topk_ids.options().dtype(torch::kInt8));

  auto permuted_experts_id = torch::empty_like(topk_ids);
  auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
  // Must start zeroed: getMIndices only writes entries 1..n_local_expert.
  auto align_expert_first_token_offset =
      torch::zeros_like(expert_first_token_offset);

  CubKeyValueSorter sorter{};
  int64_t* valid_num_ptr = nullptr;

  // Expert-parallel pre-processing.
  // expert_map[g] is the local id of global expert g, or -1 when g is not
  // owned by this rank. Local experts are renumbered to [0, n_local_expert)
  // and non-local ones are shifted up by n_expert, so the sort below places
  // all local rows first.
  // Example: 4 experts, ep_size = 2, ep_rank = 1 owns global ids [2, 3] with
  // expert_map = [-1, -1, 0, 1]. Global ids [2, 3] become [0, 1] and global
  // ids [0, 1] become [4, 5].
  if (expert_map.has_value()) {
    const int* expert_map_ptr = get_ptr<int>(expert_map.value());
    // The last offset entry holds the number of rows that belong to local
    // experts; passing it on enables the "skip non-local rows" kernel path.
    valid_num_ptr =
        get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
    preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
                             expert_map_ptr, n_expert, stream);
  }

  // Sort the expert ids and scan them into expert_first_token_offset.
  sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
                    get_ptr<int>(permuted_experts_id),
                    get_ptr<int>(dst_row_id2src_row_id_map),
                    get_ptr<int64_t>(expert_first_token_offset), n_token,
                    n_expert, n_local_expert, topk, sorter,
                    sort_workspace.data_ptr(), stream);

  // Copy the rows into their permuted position, dispatching on input dtype.
  MOE_DISPATCH(input.scalar_type(), [&] {
    expandInputRowsKernelLauncher<scalar_t>(
        get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
        get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
        get_ptr<int>(dst_row_id2src_row_id_map),
        get_ptr<int>(src_row_id2dst_row_id_map),
        get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
        n_hidden, topk, n_local_expert, align_block_size_value, stream);
  });

  // Fill m_indices and compute the block-aligned expert offsets.
  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
              get_ptr<int64_t>(align_expert_first_token_offset),
              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
              stream);
  if (align_block_size.has_value()) {
    // Publish the aligned offsets to the caller.
    expert_first_token_offset.copy_(align_expert_first_token_offset);
  }
}
/// Reverses moe_permute: for every token, gathers its `topk` expert outputs
/// from `permuted_hidden_states`, scales each by the matching entry of
/// `topk_weights`, sums them and writes the result into `hidden_states`.
///
/// The last entry of `expert_first_token_offset` is passed to the kernel as
/// the count of valid permuted rows; rows at or beyond it are skipped.
/// `n_expert` is accepted for interface compatibility and is not used here.
///
/// Raises (TORCH_CHECK) when a dtype or shape precondition fails.
void moe_unpermute(
    const torch::Tensor& permuted_hidden_states,     // [n_token * topk, hidden]
    const torch::Tensor& topk_weights,               // [n_token, topk]
    const torch::Tensor& topk_ids,                   // [n_token, topk]
    const torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
    const torch::Tensor& expert_first_token_offset,  // [n_local_expert+1]
    int64_t n_expert, int64_t n_local_expert, int64_t topk,
    torch::Tensor& hidden_states  // [n_token, hidden]
) {
  TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
              "topk_ids shape must be same as src_row_id2dst_row_id_map");
  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
              "topk_ids must be int32");
  TORCH_CHECK(
      permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
      "permuted_hidden_states dtype must be same as hidden_states");
  // The buffers below are reinterpreted as float / int / int64_t, so their
  // dtypes are validated the same way moe_permute validates them.
  TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
              "topk_weights must be float32");
  TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
              "src_row_id2dst_row_id_map must be int32");
  TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
              "expert_first_token_offset must be int64");

  const auto n_token = hidden_states.size(0);
  const auto n_hidden = hidden_states.size(1);
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  // Points at the final offset entry (index n_local_expert).
  const int64_t* valid_ptr =
      get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;

  MOE_DISPATCH(hidden_states.scalar_type(), [&] {
    finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
        get_ptr<scalar_t>(permuted_hidden_states),
        get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
        get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
        n_token, n_hidden, topk, valid_ptr, stream);
  });
}
/// Row gather: output[dst, :] = input[dst2src_map[dst], :].
///
/// One thread block handles one destination row (blockIdx.x); its threads
/// stride over the row in 128-bit vectors. `num_cols` must be a multiple of
/// the vector width, otherwise the tail of each row is not copied.
/// `num_src_rows` is accepted but not used by the kernel.
template <typename T>
__global__ void shuffleInputRowsKernel(const T* input,
                                       const int32_t* dst2src_map, T* output,
                                       int64_t num_src_rows,
                                       int64_t num_dst_rows, int64_t num_cols) {
  int64_t const dest_row_idx = blockIdx.x;
  // Bounds check first, so the map is never read for a row that does not
  // exist.
  if (dest_row_idx >= num_dst_rows) {
    return;
  }
  int64_t const source_row_idx = dst2src_map[dest_row_idx];

  // Load 128-bits per thread
  constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
  using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

  // Duplicate and permute rows
  auto const* source_row_ptr =
      reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
  auto* dest_row_ptr =
      reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = blockDim.x;
  int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;

  // 64-bit index so the comparison against num_elems_in_col never narrows.
  for (int64_t elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    dest_row_ptr[elem_index] = source_row_ptr[elem_index];
  }
}
/// Gathers rows of `input_tensor` into `output_tensor`:
/// output_tensor[i, :] = input_tensor[dst2src_map[i], :] for every row i of
/// the output. The copy is enqueued on the current CUDA stream.
///
/// Requirements (TORCH_CHECK):
///   * input and output share a dtype;
///   * `dst2src_map` is int32 (enforced by data_ptr<int32_t>()) and has at
///     least one entry per output row;
///   * the column count is a multiple of the number of elements that fit in
///     the 128-bit vector the kernel copies with.
void shuffle_rows(const torch::Tensor& input_tensor,
                  const torch::Tensor& dst2src_map,
                  torch::Tensor& output_tensor) {
  TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
              "Input and output tensors must have the same data type");

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  int64_t const blocks = output_tensor.size(0);
  int64_t const threads = 256;
  int64_t const num_dest_rows = output_tensor.size(0);
  int64_t const num_src_rows = input_tensor.size(0);
  int64_t const num_cols = input_tensor.size(1);

  // The kernel moves 16 bytes (128 bits) at a time, so the vector width in
  // elements depends on the element size of the tensor. (Taking sizeof() of
  // scalar_type() would measure the dtype enum, not the element.)
  int64_t const elems_per_vector =
      16 / static_cast<int64_t>(input_tensor.element_size());
  TORCH_CHECK(elems_per_vector > 0 && num_cols % elems_per_vector == 0,
              "num_cols must be divisible by the number of elements in a "
              "128-bit vector (16 / element_size)");
  TORCH_CHECK(dst2src_map.numel() >= num_dest_rows,
              "dst2src_map must have one entry per output row");

  // A grid of zero blocks is an invalid launch configuration; nothing to do.
  if (blocks == 0) {
    return;
  }

  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  });
}
#else
/// Fallback used when the file is built without CUDA >= 12.0: keeps the
/// symbol available for op registration but always raises.
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
                 torch::Tensor& topk_ids,
                 const torch::Tensor& token_expert_indicies,
                 const std::optional<torch::Tensor>& expert_map,
                 int64_t n_expert, int64_t n_local_expert, int64_t topk,
                 const std::optional<int64_t>& align_block_size,
                 torch::Tensor& permuted_input,
                 torch::Tensor& expert_first_token_offset,
                 torch::Tensor& src_row_id2dst_row_id_map,
                 torch::Tensor& m_indices) {
  // Name the op that was actually called.
  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}
void
moe_unpermute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
token_expert_indicies
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
m_indices
)
{
TORCH_CHECK
(
false
,
"moe_unpermute is not supported on CUDA < 12.0"
);
}
#endif
/// Reports whether this build contains the real permute/unpermute kernels.
/// The answer is fixed at compile time: true only when CUDA_VERSION is
/// defined and is at least 12000.
bool moe_permute_unpermute_supported() {
#if !defined(CUDA_VERSION) || (CUDA_VERSION < 12000)
  constexpr bool kKernelsCompiled = false;
#else
  constexpr bool kKernelsCompiled = true;
#endif
  return kKernelsCompiled;
}
// Registers the CUDA-backend implementations of the "moe_permute" and
// "moe_unpermute" ops on the extension's library object `m`.
// NOTE(review): the op schemas are declared outside this file; the types of
// the functions bound here must match those schemas — confirm.
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"moe_permute"
,
&
moe_permute
);
m
.
impl
(
"moe_unpermute"
,
&
moe_unpermute
);
}
csrc/moe/moe_wna16_utils.h
View file @
4c676e3d
...
...
@@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) {
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
q
>>=
8
;
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
res
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo0
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
...
...
@@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
...
...
csrc/moe/permute_unpermute_kernels/dispatch.h
0 → 100644
View file @
4c676e3d
#pragma once
#include <cuda_fp8.h>
// Declares `_st` from TYPE and opens a `switch` over it. The `case` labels
// are supplied by the caller through __VA_ARGS__; any scalar type without a
// case reaches `default` and fails the TORCH_CHECK.
// Because `_st` is declared in the enclosing scope, the macro can be expanded
// at most once per scope.
// NOTE(review): `::detail::scalar_type` is not declared in this header; it
// presumably comes from PyTorch's dispatch headers — confirm include order.
#define MOE_SWITCH(TYPE, ...) \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
}
// Emits one `case` for `enum_type`: aliases `scalar_t` to the CUDA type that
// ScalarType2CudaType maps it to, then calls the callable in __VA_ARGS__.
#define MOE_DISPATCH_CASE(enum_type, ...) \
case enum_type: { \
using scalar_t = ScalarType2CudaType<enum_type>::type; \
__VA_ARGS__(); \
break; \
}
// The complete set of dtypes the permute kernels are instantiated for:
// Float, Half, BFloat16, Float8_e5m2, Float8_e4m3fn and Byte.
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
// Entry point. Usage: MOE_DISPATCH(dtype, [&] { kernel<scalar_t>(...); });
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
// Compile-time map from an at::ScalarType enumerator to the CUDA element type
// used for it inside the kernels. Only the specializations below exist; using
// any other enumerator is a compile error because the primary template is
// declared but never defined.
template <at::ScalarType type>
struct ScalarType2CudaType;

// 32-bit float.
template <>
struct ScalarType2CudaType<at::ScalarType::Float> {
  using type = float;
};

// 16-bit half precision.
template <>
struct ScalarType2CudaType<at::ScalarType::Half> {
  using type = half;
};

// 16-bit bfloat.
template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
  using type = __nv_bfloat16;
};

// uint8 for packed fp4
template <>
struct ScalarType2CudaType<at::ScalarType::Byte> {
  using type = uint8_t;
};

// #if __CUDA_ARCH__ >= 890
// fp8
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
  using type = __nv_fp8_e5m2;
};

template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
  using type = __nv_fp8_e4m3;
};
// #endif
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
0 → 100644
View file @
4c676e3d
#include "moe_permute_unpermute_kernel.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
// CubKeyValueSorter definition begin
// Default state: no expert count known yet, so the sort would cover every bit
// of an `int` key until updateNumExperts() narrows it.
CubKeyValueSorter::CubKeyValueSorter()
    : num_experts_(0), num_bits_(static_cast<int>(sizeof(int) * 8)) {}
// Number of low-order key bits the radix sort has to look at.
// The largest key we represent is V = num_experts + (num_experts - 1)
// = 2 * num_experts - 1, which needs floor(log2(V)) + 1 bits.
int CubKeyValueSorter::expertsToBits(int num_experts) {
  int const largest_key = 2 * num_experts - 1;
  int const floor_log2 = static_cast<int>(log2(largest_key));
  return floor_log2 + 1;
}
// Builds a sorter sized for `num_experts`; the sorted bit range is derived
// from that count via expertsToBits().
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
    : num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
// Re-targets an existing sorter at a new expert count, refreshing both the
// stored count and the derived number of key bits.
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
  num_bits_ = expertsToBits(num_experts);
  num_experts_ = num_experts;
}
// Returns the number of scratch bytes the radix sort needs for
// `num_key_value_pairs` int/int pairs with keys bounded by `num_experts`.
// The result is never 0, so callers can always allocate a buffer of that size.
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
                                           int const num_experts) {
  int const end_bit = expertsToBits(num_experts);
  size_t workspace_bytes = 0;
  int* no_buffer = nullptr;

  // With a null workspace pointer the call only reports the required size
  // through `workspace_bytes`.
  cub::DeviceRadixSort::SortPairs(nullptr, workspace_bytes, no_buffer,
                                  no_buffer, no_buffer, no_buffer,
                                  num_key_value_pairs, 0, end_bit);

  // Observed with num_key_value_pairs = 64, num_experts = 4, num_bits = 3:
  // the reported size seems to vary between 0 and 1 for the same inputs, so
  // clamp it to at least one byte.
  return workspace_bytes == 0 ? 1 : workspace_bytes;
}
// Sorts `num_key_value_pairs` (key, value) pairs by key on `stream`, looking
// only at the low `num_bits_` bits of each key.
// Raises (TORCH_CHECK) when `workspace_size` is smaller than what
// getWorkspaceSize() reports for this problem.
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
                            int const* keys_in, int* keys_out,
                            int const* values_in, int* values_out,
                            size_t const num_key_value_pairs,
                            cudaStream_t stream) {
  size_t const required_bytes =
      getWorkspaceSize(num_key_value_pairs, num_experts_);
  TORCH_CHECK(required_bytes <= workspace_size,
              "[CubKeyValueSorter::run] The allocated workspace is too small "
              "to run this problem.");

  // cub takes the size by non-const reference, hence the mutable copy.
  size_t provided_bytes = workspace_size;
  cub::DeviceRadixSort::SortPairs(workspace, provided_bytes, keys_in, keys_out,
                                  values_in, values_out, num_key_value_pairs,
                                  0, num_bits_, stream);
}
// CubKeyValueSorter definition end
// Rounds `input` up to the next multiple of 16 (values already on a multiple
// of 16, including 0, are returned unchanged).
static inline size_t pad_to_multiple_of_16(size_t const& input) {
  constexpr size_t kAlignment = 16;
  size_t const chunks = (input + kAlignment - 1) / kAlignment;
  return chunks * kAlignment;
}
// Binary search over an ascending array: returns how many of the first
// `arr_length` entries of `sorted_indices` are strictly less than `target`
// (equivalently, the index of the first entry that is >= target).
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
                                                      int64_t const arr_length,
                                                      T const target) {
  int64_t first = 0;
  int64_t last = arr_length - 1;
  // Invariant: every index below `first` holds a value < target.
  while (first <= last) {
    int64_t const probe = (first + last) / 2;
    if (sorted_indices[probe] >= target) {
      last = probe - 1;
    } else {
      first = probe + 1;
    }
  }
  return first;
}
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens.
// One thread per output entry; entries run over [0, num_experts] inclusive so
// that the final one carries the total count.
__global__ void computeExpertFirstTokenOffsetKernel(
    int const* sorted_experts, int64_t const sorted_experts_len,
    int const num_experts, int64_t* expert_first_token_offset) {
  // Global thread id doubles as the expert id.
  int const expert = blockIdx.x * blockDim.x + threadIdx.x;

  if (!(expert >= num_experts + 1)) {
    expert_first_token_offset[expert] = findTotalEltsLessThanTarget(
        sorted_experts, sorted_experts_len, expert);
  }
}
// Host launcher for computeExpertFirstTokenOffsetKernel: schedules one thread
// per offset entry (num_experts + 1 of them) on `stream`, in blocks of at
// most 1024 threads.
void computeExpertFirstTokenOffset(int const* sorted_indices,
                                   int const total_indices,
                                   int const num_experts,
                                   int64_t* expert_first_token_offset,
                                   cudaStream_t stream) {
  int const entry_count = num_experts + 1;
  int const block_dim = std::min(1024, entry_count);
  int const grid_dim = (entry_count + block_dim - 1) / block_dim;

  computeExpertFirstTokenOffsetKernel<<<grid_dim, block_dim, 0, stream>>>(
      sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
// Sorts the k * num_rows (expert id, source row) pairs by expert id and then
// derives, for each of the num_experts_per_node local experts, the position of
// its first row in the sorted order.
//   expert_for_source_row / source_rows : sort keys / payload (inputs)
//   permuted_experts / permuted_rows    : sorted keys / payload (outputs)
//   expert_first_token_offset           : num_experts_per_node + 1 offsets
//   sorter_ws                           : scratch buffer for the sort
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
                       int* permuted_experts, int* permuted_rows,
                       int64_t* expert_first_token_offset, int num_rows,
                       int num_experts, int num_experts_per_node, int k,
                       CubKeyValueSorter& sorter, void* sorter_ws,
                       cudaStream_t stream) {
  int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;

  // We need to use the full num_experts because that is the sentinel value
  // used by topk for disabled experts.
  sorter.updateNumExperts(num_experts);

  size_t const raw_ws_bytes =
      CubKeyValueSorter::getWorkspaceSize(expanded_num_rows, num_experts);
  size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(raw_ws_bytes);

  sorter.run(sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
             permuted_experts, source_rows, permuted_rows, expanded_num_rows,
             stream);

  computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
                                num_experts_per_node,
                                expert_first_token_offset, stream);
}
// Rewrites `size` expert ids in place using the expert map:
//   * if expert_map[id] == -1, the id becomes id + num_experts;
//   * otherwise the id becomes expert_map[id].
// Each block first stages the whole expert map in dynamic shared memory
// (num_experts ints), then every thread converts one id.
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
                                       const int* expert_map_ptr,
                                       int num_experts) {
  auto const lane = threadIdx.x;
  auto const block_begin = blockIdx.x * blockDim.x;
  auto const block_end = min(block_begin + blockDim.x, size);

  extern __shared__ int smem_expert_map[];
  // Cooperative copy of the expert map into shared memory.
  for (int i = lane; i < num_experts; i += blockDim.x) {
    smem_expert_map[i] = expert_map_ptr[i];
  }
  __syncthreads();

  auto const pos = block_begin + lane;
  if (pos < block_end) {
    auto const global_id = topk_id_ptr[pos];
    auto const local_id = smem_expert_map[global_id];
    // -1 in the map marks an expert that has no local id.
    auto const mapped = (local_id == -1) ? global_id + num_experts : local_id;
    __syncwarp();
    topk_id_ptr[pos] = mapped;
  }
}
// Host launcher for preprocessTopkIdKernel: one thread per expert id, blocks
// of at most 1024 threads, and num_experts ints of dynamic shared memory for
// the staged expert map. Does nothing when there are no ids to convert.
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
                              const int* expert_map_ptr, int num_experts,
                              cudaStream_t stream) {
  // With size == 0 the block size below would be 0 and the grid computation
  // would divide by zero, so an empty batch returns early.
  if (size <= 0) {
    return;
  }
  int block = std::min(size, 1024);
  int grid = (size + block - 1) / block;
  int smem_size = (num_experts) * sizeof(int);
  preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
      topk_id_ptr, size, expert_map_ptr, num_experts);
}
// Fills m_indices with the expert id of every permuted row.
// One block per local expert (blockIdx.x); the block's threads stride over the
// rows that belong to that expert.
//
// When ALIGN_BLOCK_SIZE is true, each expert's row count is rounded up to a
// multiple of align_block_size, rows are placed at the aligned start offset,
// and the block of the last expert also writes the aligned prefix sums into
// align_expert_first_token_offset[1..num_local_expert] (entry 0 is left
// untouched).
//
// Requires (num_local_expert + 1) * sizeof(int64_t) bytes of dynamic shared
// memory.
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
                                  int64_t* align_expert_first_token_offset,
                                  int* m_indices, const int num_local_expert,
                                  const int align_block_size) {
  int eidx = blockIdx.x;
  int tidx = threadIdx.x;
  extern __shared__ int64_t smem_expert_first_token_offset[];

  // Stage all num_local_expert + 1 offsets in shared memory. Each entry goes
  // to its own slot `i`; indexing by thread id instead would drop every entry
  // past blockDim.x and overwrite the early ones.
  for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
    smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
  }
  __syncthreads();

  auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
  auto first_token_offset = smem_expert_first_token_offset[eidx];
  int n_token_in_expert = last_token_offset - first_token_offset;

  if constexpr (ALIGN_BLOCK_SIZE) {
    // round up to ALIGN_BLOCK_SIZE
    n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
                        align_block_size * align_block_size;

    // Prefix sum of the aligned sizes of experts 0..eidx. After iteration i
    // the accumulator is the aligned start of expert i.
    int64_t accumulate_align_offset = 0;
    for (int i = 1; i <= eidx + 1; i++) {
      int n_token = smem_expert_first_token_offset[i] -
                    smem_expert_first_token_offset[i - 1];
      accumulate_align_offset =
          accumulate_align_offset + (n_token + align_block_size - 1) /
                                        align_block_size * align_block_size;
      if (i == eidx) {
        first_token_offset = accumulate_align_offset;
      }
      // last block store align_expert_first_token_offset
      if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
        align_expert_first_token_offset[i] = accumulate_align_offset;
      }
    }
  }

  for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
    // update m_indice with expert id
    m_indices[first_token_offset + idx] = eidx;
  }
}
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
)
{
int
block
=
256
;
int
grid
=
num_local_expert
;
int
smem_size
=
sizeof
(
int64_t
)
*
(
num_local_expert
+
1
);
if
(
align_block_size
==
-
1
)
{
getMIndicesKernel
<
false
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
else
{
getMIndicesKernel
<
true
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
}
#endif
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
0 → 100644
View file @
4c676e3d
#pragma once
// Adapted from the TensorRT-LLM MoE kernel implementation, as archived at
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
// Returns the tensor's data pointer viewed as `T*`. No dtype check is made;
// the caller is responsible for T matching the tensor's element type.
template <typename T>
inline T* get_ptr(torch::Tensor& t) {
  void* const raw = t.data_ptr();
  return static_cast<T*>(raw);
}
// Read-only overload: returns the tensor's data pointer viewed as `const T*`.
// No dtype check is made; the caller is responsible for T matching the
// tensor's element type.
template <typename T>
inline const T* get_ptr(const torch::Tensor& t) {
  const void* const raw = t.data_ptr();
  return static_cast<const T*>(raw);
}
// Thin wrapper around cub's device radix sort for (int key, int value) pairs
// whose keys are expert ids. It remembers the expert count and the number of
// key bits derived from it.
class CubKeyValueSorter {
 public:
  // Sorter with no expert count set.
  CubKeyValueSorter();

  // Sorter sized for `num_experts`.
  CubKeyValueSorter(const int num_experts);

  // Changes the expert count and the derived key-bit count.
  void updateNumExperts(const int num_experts);

  // Scratch bytes needed to sort `num_key_value_pairs` pairs.
  static size_t getWorkspaceSize(const size_t num_key_value_pairs,
                                 const int num_experts);

  // Sorts the pairs by key on `stream`, using `workspace` as scratch.
  void run(void* workspace, const size_t workspace_size, const int* keys_in,
           int* keys_out, const int* values_in, int* values_out,
           const size_t num_key_value_pairs, cudaStream_t stream);

 private:
  // Number of key bits required for the given expert count.
  static int expertsToBits(int experts);

  int num_experts_;
  int num_bits_;
};
// Writes num_experts + 1 offsets: entry e is the number of sorted ids that
// are smaller than e.
void computeExpertFirstTokenOffset(const int* sorted_indices,
                                   const int total_indices,
                                   const int num_experts,
                                   int64_t* expert_first_token_offset,
                                   cudaStream_t stream);

// Sorts the (expert id, source row) pairs by expert id and fills the
// per-expert first-row offsets.
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
                       int* permuted_experts, int* permuted_rows,
                       int64_t* expert_first_token_offset, int num_rows,
                       int num_experts, int num_experts_per_node, int k,
                       CubKeyValueSorter& sorter, void* sorter_ws,
                       cudaStream_t stream);

// Copies input rows into their permuted position and records the
// source-row -> destination-row map.
template <typename T>
void expandInputRowsKernelLauncher(
    const T* unpermuted_input, T* permuted_output,
    const float* unpermuted_scales, int* sorted_experts,
    const int* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* expert_first_token_offset, const int64_t num_rows,
    const int64_t* num_valid_tokens_ptr, const int64_t cols, const int k,
    int num_local_experts, const int& align_block_size, cudaStream_t stream);

// Final kernel to unpermute and scale.
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
    const T* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    const float* scales, const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row, const int64_t orig_cols, const int64_t k,
    const int64_t* num_valid_ptr);

// Host launcher for finalizeMoeRoutingKernel.
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
    const T* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    const float* scales, const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row, const int64_t num_rows,
    const int64_t cols, const int64_t k, const int64_t* num_valid_ptr,
    cudaStream_t stream);

// Rewrites `size` expert ids in place according to the expert map.
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
                              const int* expert_map_ptr, int num_experts,
                              cudaStream_t stream);

// Fills m_indices with the expert id of each permuted row and, when
// align_block_size != -1, computes the block-aligned expert offsets.
void getMIndices(int64_t* expert_first_token_offset,
                 int64_t* align_expert_first_token_offset, int* m_indices,
                 int num_local_expert, const int align_block_size,
                 cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl"
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
0 → 100644
View file @
4c676e3d
#pragma once
// Copies one input row into its permuted position. One thread block per
// expanded destination row (blockIdx.x).
//
//  * Also records the reverse permutation (source -> dest) so that the later
//    k-way reduction can be done by one thread block per token without
//    atomics.
//  * ALIGN_BLOCK_SIZE: the destination row is re-based so that every local
//    expert's rows start at a multiple of align_block_size.
//  * CHECK_SKIPPED: rows whose block index is not below *num_dest_rows are
//    not copied (their map entry is still written).
// `unpermuted_scales` and `k` are accepted but not used by the kernel.
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
    T const* unpermuted_input, T* permuted_output,
    const float* unpermuted_scales, int* sorted_experts,
    int const* expanded_dest_row_to_expanded_source_row,
    int* expanded_source_row_to_expanded_dest_row,
    int64_t* expert_first_token_offset, int64_t const num_rows,
    int64_t const* num_dest_rows, int64_t const cols, int64_t k,
    int num_local_experts, int align_block_size) {
  int64_t dst_row = blockIdx.x;
  int64_t const src_row_expanded =
      expanded_dest_row_to_expanded_source_row[dst_row];
  int const expert = sorted_experts[dst_row];

  extern __shared__ int64_t smem_expert_first_token_offset[];

  if constexpr (ALIGN_BLOCK_SIZE) {
    // Stage the num_local_experts + 1 offsets in shared memory.
    for (int slot = threadIdx.x; slot < num_local_experts + 1;
         slot += blockDim.x) {
      smem_expert_first_token_offset[slot] =
          __ldg(expert_first_token_offset + slot);
    }
    __syncthreads();

    // Lane 0 of every warp computes the aligned row, then broadcasts it.
    bool const is_first_lane = (threadIdx.x % 32) == 0;
    if (is_first_lane) {
      // Position of this row inside its expert; 0 for non-local experts.
      int offset_in_expert = 0;
      if (expert < num_local_experts) {
        offset_in_expert = dst_row - smem_expert_first_token_offset[expert];
      }

      // Sum of the aligned sizes of all local experts before this one.
      int64_t aligned_base = 0;
      int const expert_limit = min(expert, num_local_experts);
#pragma unroll 1
      for (int e = 1; e <= expert_limit; e++) {
        auto const rows_in_expert = smem_expert_first_token_offset[e] -
                                    smem_expert_first_token_offset[e - 1];
        aligned_base += (rows_in_expert + align_block_size - 1) /
                        align_block_size * align_block_size;
      }
      dst_row = aligned_base + offset_in_expert;
    }
    // lane0 shuffle broadcast align_expanded_dest_row
    dst_row = __shfl_sync(0xffffffff, dst_row, 0);
  }

  // One thread per block publishes the source -> dest mapping.
  if (threadIdx.x == 0) {
    assert(dst_row <= INT32_MAX);
    expanded_source_row_to_expanded_dest_row[src_row_expanded] =
        static_cast<int>(dst_row);
  }

  if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
    // Load 128-bits per thread
    constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

    // Duplicate and permute rows
    int64_t const src_row = src_row_expanded % num_rows;
    auto const* src_vec =
        reinterpret_cast<DataElem const*>(unpermuted_input + src_row * cols);
    auto* dst_vec =
        reinterpret_cast<DataElem*>(permuted_output + dst_row * cols);

    int64_t const first_vec = threadIdx.x;
    int64_t const vec_stride = blockDim.x;
    int64_t const vecs_per_row = cols / ELEM_PER_THREAD;

    for (int v = first_vec; v < vecs_per_row; v += vec_stride) {
      dst_vec[v] = src_vec[v];
    }
  }
}
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
align_block_size);
}
// Element-wise conversion between two fixed-size array types with the same
// element count: every element of `input` is static_cast to U's element type.
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
  static_assert(T::kElements == U::kElements);
  using Type = typename U::Element;
  U converted;
#pragma unroll
  for (int lane = 0; lane < U::kElements; lane++) {
    converted[lane] = static_cast<Type>(input[lane]);
  }
  return converted;
}
// Un-permutes and reduces the expert outputs. One thread block per original
// row (blockIdx.x; gridDim.x is taken as the row count); threads stride over
// the row in 128-bit vectors.
//
// For every output vector the kernel accumulates, in float,
//   sum over k_idx of scales[row * k + k_idx] * permuted_row(k_idx)
// where permuted_row(k_idx) is located through
// expanded_source_row_to_expanded_dest_row[row + k_idx * num_rows].
//
// CHECK_SKIPPED: contributions whose permuted row index is >= *num_valid_ptr
// are left out. When CHECK_SKIPPED is false, num_valid_ptr is never
// dereferenced and may be null.
// `expert_for_source_row` is accepted but not read by the kernel.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
    T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    float const* scales, int const* expanded_source_row_to_expanded_dest_row,
    int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
    int64_t const* num_valid_ptr) {
  assert(orig_cols % 4 == 0);
  int64_t const original_row = blockIdx.x;
  int64_t const num_rows = gridDim.x;
  auto const offset = original_row * orig_cols;
  OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;

  // Only read the valid-row count in the instantiation that uses it; the
  // launcher selects CHECK_SKIPPED == false exactly when the pointer is null.
  int64_t num_valid = 0;
  if constexpr (CHECK_SKIPPED) {
    num_valid = *num_valid_ptr;
  }

  // Load 128-bits per thread, according to the smallest data type we read/write
  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
      128 / std::min(cutlass::sizeof_bits<OutputType>::value,
                     cutlass::sizeof_bits<T>::value);

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = blockDim.x;
  int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;

  using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
  using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
  auto const* expanded_permuted_rows_v =
      reinterpret_cast<InputElem const*>(expanded_permuted_rows);
  auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

#pragma unroll
  for (int elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    ComputeElem thread_output;
    thread_output.fill(0);
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      int64_t const expanded_original_row = original_row + k_idx * num_rows;
      int64_t const expanded_permuted_row =
          expanded_source_row_to_expanded_dest_row[expanded_original_row];

      int64_t const k_offset = original_row * k + k_idx;
      float const row_scale = scales[k_offset];

      // Skip contributions that were never produced (row beyond valid range).
      if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
        continue;
      }

      auto const* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

      ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
          expanded_permuted_rows_row_ptr[elem_index]);
      thread_output = thread_output + row_scale * (expert_result);
    }

    OutputElem output_elem =
        arrayConvert<ComputeElem, OutputElem>(thread_output);
    reduced_row_ptr_v[elem_index] = output_elem;
  }
}
// Launches finalizeMoeRoutingKernel on `stream`, using one thread block per
// row (`num_rows` blocks) with 256 threads per block.
//
// Kernel selection: when `num_valid_ptr` is non-null the instantiation whose
// third template argument is `true` is launched; otherwise the `false`
// instantiation is used.
//
// Arguments:
//   expanded_permuted_rows, reduced_unpermuted_output, scales,
//   expanded_source_row_to_expanded_dest_row, expert_for_source_row,
//   cols, k, num_valid_ptr
//       forwarded to the kernel unchanged, in this order.
//   num_rows
//       only used here to size the grid; it is not passed to the kernel.
//   stream
//       CUDA stream the kernel is enqueued on.
//
// An empty batch (num_rows <= 0) returns without launching anything: a
// zero-sized grid is not a valid execution configuration and would make the
// launch fail with an invalid-configuration error.
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
    T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
    float const* scales, int const* expanded_source_row_to_expanded_dest_row,
    int const* expert_for_source_row, int64_t const num_rows,
    int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
    cudaStream_t stream) {
  // Nothing to reduce; also avoids an invalid launch with gridDim.x == 0.
  if (num_rows <= 0) {
    return;
  }

  int64_t const blocks = num_rows;
  int64_t const threads = 256;

  // Index 0 -> no skip check, index 1 -> skip check enabled.
  bool const check_finished = num_valid_ptr != nullptr;
  using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
  FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
                         &finalizeMoeRoutingKernel<T, OutputType, true>};
  auto* const kernel = func_map[check_finished];

  kernel<<<blocks, threads, 0, stream>>>(
      expanded_permuted_rows, reduced_unpermuted_output, scales,
      expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
      num_valid_ptr);
}
csrc/moe/topk_softmax_kernels.cu
View file @
4c676e3d
...
...
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
template
<
int
TPB
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
template
<
int
TPB
,
typename
IndType
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
...
...
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
typename
IndType
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
int
*
indices
,
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
...
...
@@ -397,8 +405,8 @@ struct TopkConstants
};
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
typename
IndType
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
...
...
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
template
<
typename
IndType
>
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
float
*
topk_weights
,
int
*
topk_indicies
,
IndType
*
topk_indicies
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
const
int
num_tokens
,
...
...
@@ -493,14 +502,44 @@ void topk_softmax(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
else
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
UInt32
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
uint32_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
else
{
assert
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int64
);
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int64_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
}
csrc/moe/torch_bindings.cpp
View file @
4c676e3d
...
...
@@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Calculate the result of moe by summing up the partial results
// from all selected experts.
m
.
def
(
"moe_sum(Tensor
!
input, Tensor output) -> ()"
);
m
.
def
(
"moe_sum(Tensor input, Tensor
!
output) -> ()"
);
m
.
impl
(
"moe_sum"
,
torch
::
kCUDA
,
&
moe_sum
);
// Aligning the number of tokens to be processed by each expert such
...
...
@@ -50,7 +50,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
...
...
@@ -59,8 +60,38 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"
);
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
def
(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"m_indices)->()"
);
// conditionally compiled so impl registration is in source file
m
.
def
(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()"
);
m
.
def
(
"moe_permute_unpermute_supported() -> bool"
);
m
.
impl
(
"moe_permute_unpermute_supported"
,
&
moe_permute_unpermute_supported
);
// Row shuffle for MoE
m
.
def
(
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
"output_tensor) -> ()"
);
m
.
impl
(
"shuffle_rows"
,
torch
::
kCUDA
,
&
shuffle_rows
);
#endif
}
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment