Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aae74ef9
Unverified
Commit
aae74ef9
authored
Aug 21, 2024
by
Michael Goin
Committed by
GitHub
Aug 22, 2024
Browse files
Revert "[Kernel] Expand MoE weight loading + Add Fused Marlin MoE Kernel (#7527)" (#7764)
parent
cde9183b
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
84 additions
and
2374 deletions
+84
-2374
CMakeLists.txt
CMakeLists.txt
+1
-2
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+0
-1740
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+0
-12
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+0
-9
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+0
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-14
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+8
-6
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+18
-116
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+36
-170
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+0
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-283
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+18
-11
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-2
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+1
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+0
-1
No files found.
CMakeLists.txt
View file @
aae74ef9
...
@@ -286,8 +286,7 @@ define_gpu_extension_target(
...
@@ -286,8 +286,7 @@ define_gpu_extension_target(
set
(
VLLM_MOE_EXT_SRC
set
(
VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu"
)
"csrc/moe/marlin_moe_ops.cu"
)
define_gpu_extension_target
(
define_gpu_extension_target
(
_moe_C
_moe_C
...
...
csrc/moe/marlin_moe_ops.cu
deleted
100644 → 0
View file @
cde9183b
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
// Returns the textual form of `x` as produced by std::to_string.
template <typename T>
inline std::string str(T x) {
  std::string text = std::to_string(x);
  return text;
}
namespace
marlin_moe
{
// Evaluates (a + b - 1) / b, i.e. a / b rounded up when both operands are
// positive.
constexpr int ceildiv(int a, int b) {
  const int biased_numerator = a + b - 1;
  return biased_numerator / b;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Fixed-size bundle of `n` values of type `T`.
//
// `Vec` instances are meant to live in >>registers<< (for example as operands
// of tensor core instructions). For that to work every index passed to
// `operator[]` has to be a compile-time constant, which is why the kernel code
// relies on `#pragma unroll` so heavily.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable access to element `idx`; no bounds checking is performed.
  __device__ T& operator[](int idx) { return elems[idx]; }
};
// Four 32-bit integers handled as one unit (same size as an `int4`).
using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;  // operand A: four 32-bit registers of fp16 pairs
using FragB = Vec<half2, 2>;  // operand B: two 32-bit registers of fp16 pairs
using FragC = Vec<float, 4>;  // fp32 accumulator / output: four floats
using FragS = Vec<half2, 1>;  // quantization scales
// Predicated asynchronous copy of one 16-byte chunk from global memory
// (`glob_ptr`) to shared memory (`smem_ptr`). The `cp.async` instruction is
// only executed when `pred` is true. Used for the A inputs, where predication
// handles batch sizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kChunkBytes = 16;
  // cp.async takes a shared-state-space address, not a generic pointer.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  const int pred_word = (int)pred;
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(pred_word), "r"(smem_addr), "l"(glob_ptr), "n"(kChunkBytes));
}
// Unconditional asynchronous copy of one 16-byte chunk from global memory
// (`glob_ptr`) to shared memory (`smem_ptr`).
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kChunkBytes = 16;
  // cp.async takes a shared-state-space address, not a generic pointer.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n"
      :
      : "r"(smem_addr), "l"(glob_ptr), "n"(kChunkBytes));
}
// Async copy fence: commits all `cp.async` operations issued so far by this
// thread into one group that `cp_async_wait` can later wait on.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n");
}
// Blocks the calling thread until at most `n` committed async copy groups
// (pipeline stages) are still in flight.
template <int n>
__device__ inline void cp_async_wait() {
  constexpr int kMaxPendingGroups = n;
  asm volatile("cp.async.wait_group %0;\n" : : "n"(kMaxPendingGroups));
}
// Issues one m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation: frag_c = a_frag * frag_b + frag_c.
//
// The fragments are reinterpreted as raw 32-bit registers because the PTX
// instruction takes four A registers, two B registers and four fp32
// accumulator registers (used both as input and output).
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                           FragC& frag_c) {
  const uint32_t* a_regs = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b_regs = reinterpret_cast<const uint32_t*>(&frag_b);
  float* acc = reinterpret_cast<float*>(&frag_c);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
      : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
        "r"(b_regs[0]), "r"(b_regs[1]),
        "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
}
// Loads a full 16x16 matrix fragment of operand A from shared memory straight
// into tensor core register layout, using a single `ldmatrix ... x4`
// instruction that fills the four 32-bit registers of `frag_a`.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* dst_regs = reinterpret_cast<uint32_t*>(&frag_a);
  // ldmatrix takes a shared-state-space address, not a generic pointer.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(dst_regs[0]), "=r"(dst_regs[1]), "=r"(dst_regs[2]),
        "=r"(dst_regs[3])
      : "r"(smem_addr));
}
// Three-input bitwise operation selected by the lookup table `lut`, emitted
// as an explicit `lop3.b32` instruction. It is spelled out for dequantization
// because the compiler does not seem to recognize the pattern automatically
// in all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int result;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(result)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return result;
}
// Efficiently dequantizes an int32 value into a full B-fragment of 4 fp16
// values. The approach mostly follows the strategy in the link below, with
// some small changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
//
// Only bits 0-3 and 4-7 of each 16-bit half of `q` are consumed here; callers
// shift `q` to reach the remaining nibbles.
__device__ inline FragB dequant(int q) {
  const int kLowNibbleMask = 0x000f000f;   // bits 0-3 of each 16-bit half
  const int kHighNibbleMask = 0x00f000f0;  // bits 4-7 of each 16-bit half
  const int kExponentBits = 0x64006400;    // fp16 1024.0 in both halves

  // LUT for `(a & b) | c`; spelling it out guarantees a single LOP3 per nibble
  // group.
  constexpr int kAndThenOrLut = (0xf0 & 0xcc) | 0xaa;
  int lo_bits = lop3<kAndThenOrLut>(q, kLowNibbleMask, kExponentBits);
  int hi_bits = lop3<kAndThenOrLut>(q, kHighNibbleMask, kExponentBits);

  // We want signed int4 outputs, so the `-8` symmetric zero point is fused
  // directly into the subtract and add constants.
  const int kSubBits = 0x64086408;  // fp16 1032.0 = 1024 + 8
  const int kMulBits = 0x2c002c00;  // fp16 1/16
  const int kAddBits = 0xd480d480;  // fp16 -72.0 = -(1024 / 16) - 8

  const half2 lo_pair = *reinterpret_cast<half2*>(&lo_bits);
  const half2 hi_pair = *reinterpret_cast<half2*>(&hi_bits);
  const half2 sub_pair = *reinterpret_cast<const half2*>(&kSubBits);
  const half2 mul_pair = *reinterpret_cast<const half2*>(&kMulBits);
  const half2 add_pair = *reinterpret_cast<const half2*>(&kAddBits);

  FragB result;
  result[0] = __hsub2(lo_pair, sub_pair);
  result[1] = __hfma2(hi_pair, mul_pair, add_pair);
  return result;
}
// Multiplies all four dequantized values in `frag_b` by one quantization
// scale; used only for grouped quantization. `i` selects which fp16 element of
// `frag_s` is used, and that element is broadcast to both lanes of a half2.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  const __half* scale_elems = reinterpret_cast<__half*>(&frag_s);
  const half2 broadcast = __half2half2(scale_elems[i]);
  frag_b[0] = __hmul2(frag_b[0], broadcast);
  frag_b[1] = __hmul2(frag_b[1], broadcast);
}
// Scales two floats in place: c[0] and c[1] are multiplied by the first and
// second fp16 element of `s`, respectively (after conversion to float).
__device__ inline void scale_float(float* c, FragS& s) {
  __half* scale_elems = reinterpret_cast<__half*>(&s);
#pragma unroll
  for (int lane = 0; lane < 2; ++lane) {
    c[lane] = __fmul_rn(c[lane], __half2float(scale_elems[lane]));
  }
}
// Scaling variant for act_order, where each K position has its own scale.
// Element `i` is taken from each of the four scale fragments: frag_b[0] is
// multiplied by (frag_s_1[i], frag_s_2[i]) and frag_b[1] by
// (frag_s_3[i], frag_s_4[i]).
__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
                              FragS& frag_s_3, FragS& frag_s_4, int i) {
  const __half s1 = reinterpret_cast<__half*>(&frag_s_1)[i];
  const __half s2 = reinterpret_cast<__half*>(&frag_s_2)[i];
  const __half s3 = reinterpret_cast<__half*>(&frag_s_3)[i];
  const __half s4 = reinterpret_cast<__half*>(&frag_s_4)[i];

  // __halves2half2(lo, hi) places `lo` in .x and `hi` in .y.
  const __half2 scales_1_2 = __halves2half2(s1, s2);
  const __half2 scales_3_4 = __halves2half2(s3, s4);

  frag_b[0] = __hmul2(frag_b[0], scales_1_2);
  frag_b[1] = __hmul2(frag_b[1], scales_3_4);
}
// Wait until barrier reaches `count`, then lock for current threadblock.
//
// Only thread 0 polls the lock word; every other thread of the block waits at
// the trailing __syncthreads(), so the whole block proceeds together once the
// expected count has been observed.
//
// lock:  pointer to the int used as the barrier counter.
// count: value the counter must hold before this block may continue.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Spin on a GPU-scope acquire load of the counter until it equals
      // `count`.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  // Hold back the rest of the block until thread 0 has seen the expected
  // count.
  __syncthreads();
}
// Release barrier and increment visitation count.
//
// All threads of the block first meet at __syncthreads(); afterwards thread 0
// alone updates the lock word.
//
// lock:  pointer to the int used as the barrier counter.
// reset: when true the counter is simply stored back to 0 (plain store, no
//        fence) instead of being incremented.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  // Make sure every thread of the block has finished its work before the
  // barrier is released.
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    // Relaxed GPU-scope reduction: adds `val` (1) to the counter.
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
// For a given "a" of size [M,K], writes a copy whose K columns are permuted
// according to "perm": out[row][k] = a[row][perm[k]], with elements handled as
// `half`.
//
// Each thread block processes `block_rows` consecutive rows starting at
// block_rows * blockIdx.x (clamped to `size_m`); within a row the threads of
// the block stride over the K columns.
//
// The row stride is computed as size_k * sizeof(half) / 16 int4 elements.
// NOTE(review): this presumably requires size_k to be a multiple of 8 --
// confirm against the caller.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  const int first_row = block_rows * blockIdx.x;
  int end_row = first_row + block_rows;
  if (end_row > size_m) {
    end_row = size_m;
  }

  // Row pitch expressed in int4 (16-byte) units.
  const int row_stride = size_k * sizeof(half) / 16;

  for (int row = first_row; row < end_row; ++row) {
    const int row_offset = row * row_stride;
    half const* src_row =
        reinterpret_cast<half const*>(a_int4_ptr + row_offset);
    half* dst_row = reinterpret_cast<half*>(out_int4_ptr + row_offset);

    // One strided loop covers both the full passes and the tail where only
    // some threads still have a column to copy.
    for (int k = threadIdx.x; k < size_k; k += blockDim.x) {
      dst_row[k] = src_row[perm_int_ptr[k]];
    }
  }
}
// Builds per-expert start offsets from the flat list of selected expert ids.
//
// Thread t counts how often expert id t occurs in topk_ids[0 .. topk_length)
// and stores the count in expert_offsets[t + 1]. Thread 0 then turns the
// counts into a running sum in which each count is first rounded up to a
// multiple of `block_size`, so expert_offsets[0] == 0 and expert_offsets[e + 1]
// is the padded end offset of expert e.
//
// The number of experts is taken from blockDim.x, and `expert_offsets` is
// written at indices 0 .. blockDim.x.
// NOTE(review): the indexing uses threadIdx.x only, so this looks like it is
// meant to be launched as a single block -- confirm against the launch site.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  const int expert_id = threadIdx.x;
  const int num_experts = blockDim.x;

  int occurrences = 0;
  for (int pos = 0; pos < topk_length; ++pos) {
    if (topk_ids[pos] == expert_id) {
      ++occurrences;
    }
  }
  expert_offsets[expert_id + 1] = occurrences;

  // All counts must be written before thread 0 starts accumulating them.
  __syncthreads();

  if (threadIdx.x == 0) {
    expert_offsets[0] = 0;
    int running_total = 0;
    for (int slot = 1; slot <= num_experts; ++slot) {
      running_total += ceildiv(expert_offsets[slot], block_size) * block_size;
      expert_offsets[slot] = running_total;
    }
  }

  __syncthreads();
}
template
<
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__device__
inline
void
MarlinMoESingle
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int
*
__restrict__
sorted_ids
,
// int32 sorted ids of experts
const
float
*
__restrict__
topk_weights
,
// float topk weights
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
expert_offsets
,
int
num_groups
,
// number of scale groups per output channel
int
expert_idx
,
// idx of current expert
int
num_experts
,
// number of experts
int
topk
,
// topk parameter of moe
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
tot_m
,
// total number of rows in A and C
int
*
locks
,
// extra global storage for barrier synchronization
bool
replicate_input
,
// do we use the same input for each expert?
bool
apply_weights
,
// apply weights to output
int
current_m_block
// current m block to start kernel computation from
)
{
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
ceildiv
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
if
constexpr
(
!
has_act_order
&&
group_blocks
!=
-
1
)
{
if
(
group_blocks
>=
thread_k_blocks
)
{
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters
=
(
group_blocks
/
thread_k_blocks
)
*
ceildiv
(
iters
,
(
group_blocks
/
thread_k_blocks
));
}
}
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
sorted_ids
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
ceildiv
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
ceildiv
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
sorted_ids
+=
16
*
thread_m_blocks
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
ceildiv
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
32
;
constexpr
int
b_sh_stride
=
32
*
thread_n_blocks
/
4
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride
);
constexpr
int
b_sh_wr_delta
=
threads
;
constexpr
int
b_sh_rd_delta
=
threads
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
constexpr
int
sorted_sh_stride
=
threads
;
constexpr
int
sorted_gl_stride
=
threads
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
(
threadIdx
.
x
%
b_sh_stride
);
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
;
int
b_sh_rd
=
threadIdx
.
x
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
;
if
constexpr
(
group_blocks
==
-
1
||
group_blocks
==
0
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
int
shs_size
;
if
constexpr
(
has_act_order
)
shs_size
=
sh_max_num_groups
*
s_sh_stride
+
threads
;
else
shs_size
=
group_blocks
>
0
?
stages
*
s_sh_stage
:
threads
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int
*
sh_sorted
=
(
int
*
)(
sh_s
+
shs_size
);
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_sh_wr_delta
*
i
+
a_sh_wr
;
int
row
=
a_idx
/
a_gl_rd_delta_o
;
if
(
row
>=
prob_m
)
{
a_sh_wr_pred
[
i
]
=
false
;
}
else
{
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
}
}
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses within a tile by
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
;
int
row
=
a_idx
/
a_gl_stride
;
int
sorted_row
=
replicate_input
?
sorted_ids
[
row
]
/
topk
:
sorted_ids
[
row
];
int
new_idx
=
sorted_row
*
a_gl_stride
+
a_idx
%
a_gl_stride
;
if
(
sorted_row
<
tot_m
*
(
replicate_input
?
1
:
topk
)
&&
new_idx
<
a_gl_stride
*
tot_m
*
(
replicate_input
?
1
:
topk
))
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
new_idx
],
a_sh_wr_pred
[
i
]);
}
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
],
B_ptr
[
i
]);
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
if
constexpr
(
has_act_order
)
{
// Fetch g_idx thread-block portion
int
full_pipe
=
a_off
;
int
cur_k
=
slice_k_start_shared_fetch
+
tb_k
*
full_pipe
;
if
(
cur_k
<
prob_k
&&
cur_k
<
slice_k_finish
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
const
*
cur_g_idx_stage_ptr
=
reinterpret_cast
<
int4
const
*>
(
&
g_idx
[
cur_k
]);
if
(
threadIdx
.
x
<
g_idx_stage
)
{
cp_async4_pred
(
&
sh_g_idx_stage
[
threadIdx
.
x
],
&
cur_g_idx_stage_ptr
[
threadIdx
.
x
]);
}
}
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// TODO we are currently hitting illegal memory accesses when fetching
// sorted_ids to shared data: fix this
auto
fetch_sorted_ids_to_shared
=
[
&
]()
{
const
int
mpt
=
ceildiv
(
prob_m
,
threads
);
for
(
int
i
=
0
;
i
<
mpt
;
i
++
)
{
if
((
i
*
sorted_gl_stride
)
+
threadIdx
.
x
<
prob_m
)
{
sh_sorted
[(
i
*
sorted_sh_stride
)
+
threadIdx
.
x
]
=
sorted_ids
[(
i
*
sorted_gl_stride
)
+
threadIdx
.
x
];
}
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
frag_b_quant
[
k
%
2
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
]);
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
group_id_1
=
sh_g_idx_int_ptr
[
0
];
int
group_id_2
=
sh_g_idx_int_ptr
[
tb_k
-
1
];
is_same_group
[
pipe
]
=
group_id_1
==
group_id_2
;
same_group_id
[
pipe
]
=
group_id_1
;
};
auto
fetch_scales_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
!
has_act_order
)
{
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
}
return
;
}
// Act-order case
// Determine K of the "current" thread-block
int
cur_k
=
slice_k_start
+
tb_k
*
full_pipe
;
if
(
cur_k
>=
prob_k
||
cur_k
>=
slice_k_finish
)
{
return
;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k
=
0
;
// Progress to current iteration
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
int
warp_row
=
warp_id
/
n_warps
;
int
warp_col
=
warp_id
%
n_warps
;
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
/*slice_n_offset +*/
(
act_s_col_warp_stride
*
warp_col
)
+
(
th_id
/
4
)
*
act_s_col_stride
;
if
(
is_same_group
[
pipe
])
{
if
(
k
%
2
==
0
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
sh_s
[(
same_group_id
[
pipe
]
-
sh_first_group_id
)
*
s_sh_stride
+
s_col_shift
];
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[(
k
-
1
)
%
2
][
0
][
0
])));
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])));
}
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
constexpr
int
k_frag_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
// Tensor core offsets per thread
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
actual_k
=
cur_k
+
k_frag_offsets
[
i
];
int
group_id
=
sh_g_idx_int_ptr
[
actual_k
];
int
rel_group_id
=
group_id
-
sh_first_group_id
;
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
sh_s
[
rel_group_id
*
s_sh_stride
+
s_col_shift
];
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
FragB
frag_b0
=
dequant
(
b_quant
);
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
}
FragB
frag_b1
=
dequant
(
b_quant_shift
);
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location, which we have to reduce over in the end. We do this in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride
;
constexpr
int
red_sh_stride
=
b_sh_stride
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
(
threadIdx
.
x
%
b_sh_stride
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
int
c_idx
=
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
);
int
sorted_row
=
sorted_ids
[
c_idx
/
c_gl_stride
];
int
new_idx
=
sorted_row
*
c_gl_stride
+
c_idx
%
c_gl_stride
;
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
new_idx
],
sorted_row
<
tot_m
*
topk
&&
(
8
*
(
i
/
2
)
+
row
<
prob_m
&&
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
sorted_ids
[
8
*
(
i
/
2
)
+
row
]
<
tot_m
*
topk
)));
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
8
*
(
i
/
2
)
+
row
<
prob_m
&&
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
sorted_ids
[
8
*
(
i
/
2
)
+
row
]
<
tot_m
*
topk
))
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
__half2float
(
reinterpret_cast
<
__half
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
__half
*>
(
&
c
)[
j
]
=
__float2half
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
int
c_idx
=
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
);
int
row
=
sorted_ids
[
c_idx
/
c_gl_stride
];
if
(
row
<
tot_m
*
topk
)
{
int
new_idx
=
row
*
c_gl_stride
+
c_idx
%
c_gl_stride
;
C
[
new_idx
]
=
c
;
}
}
}
}
}
};
// Write out the reduced final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
half2
res
=
__halves2half2
(
__float2half
(
c0
),
__float2half
(
c1
));
// For per-column quantization we finally apply the scale here
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
((
half2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
ceildiv
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
int
row
=
sorted_ids
[
c_gl_wr
/
c_gl_stride
];
if
(
row
<
tot_m
*
topk
)
{
int
off
=
row
*
c_gl_stride
+
c_gl_wr
%
c_gl_stride
;
if
(
!
apply_weights
)
{
C
[
off
]
=
sh
[
c_sh_rd
];
}
else
{
__half
*
ctrg
=
reinterpret_cast
<
__half
*>
(
&
C
[
off
]);
__half
*
csrc
=
reinterpret_cast
<
__half
*>
(
&
sh
[
c_sh_rd
]);
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
ctrg
[
j
]
=
__float2half
(
topk_weights
[
row
]
*
__half2float
(
csrc
[
j
]));
}
}
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
// TODO re-enable after fixing this function
// fetch_sorted_ids_to_shared();
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
if
(
has_act_order
&&
i
==
0
)
{
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to have noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
(
last
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
(
last
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
if
constexpr
(
has_act_order
)
{
slice_k_start
=
tb_k
*
slice_row
;
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
}
}
// Per-expert entry kernel. It looks up the slice of `sorted_ids_base` that
// belongs to `expert_idx`, works out how many 16-row m-blocks remain starting
// from `current_m_block`, and forwards to the MarlinMoESingle instantiation
// whose thread_m_blocks (1..4) matches that count. The `thread_m_blocks`
// template parameter of this kernel itself is not read in the body.
template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,    // whether act_order is enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void MarlinMoE(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int* __restrict__ sorted_ids_base,  // int32 sorted ids of experts
    const float* __restrict__ topk_weights,   // float topk weights
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const int* __restrict__ g_idx,        // int32 group indices of shape k
    const int* __restrict__ expert_offsets,
    int num_groups,   // number of scale groups per output channel
    int expert_idx,   // idx of current expert
    int num_experts,  // number of experts
    int topk,         // topk parameter of moe
    int prob_m,       // batch dimension m
    int prob_n,       // output dimension n
    int prob_k,       // reduction dimension k
    int tot_m,        // total number of rows in A and C
    int* locks,       // extra global storage for barrier synchronization
    bool replicate_input,  // do we use the same input for each expert?
    bool apply_weights,    // apply weights to output
    int current_m_block,   // current m block to start kernel computation from
    int max_par            // maximum parallelism
) {
  // Number of sorted-id entries between this expert's offset and the next.
  const int expert_begin = expert_offsets[expert_idx];
  const int tot_its = expert_offsets[expert_idx + 1] - expert_begin;
  if (tot_its == 0) {
    return;
  }

  // Nothing left for this launch once the start block is past the last block.
  const int tot_m_blocks = ceildiv(tot_its, 16);
  if (current_m_block >= tot_m_blocks) {
    return;
  }

  const int* sorted_ids_expert =
      sorted_ids_base + expert_begin + current_m_block * 4 * max_par;

  // Rows missing from the last 16-row block of this expert.
  const int pad = 16 * tot_m_blocks - tot_its;

  // The incoming prob_m is replaced by the rows still to be processed.
  int max_block = tot_m_blocks - current_m_block;
  prob_m = tot_its - 16 * current_m_block;

  if (max_block > 4) {
    // Note that parallel > 1 currently only works for inputs without any
    // padding
    const int par = min((16 * max_block - pad) / 64, max_par);
    prob_m = 64 * par;
    max_block = 4;
  }

  // max_block is now in [1, 4]; pick the matching compile-time m-block count.
  switch (max_block) {
    case 1:
      MarlinMoESingle<threads, 1, thread_n_blocks, thread_k_blocks, stages,
                      has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
      break;
    case 2:
      MarlinMoESingle<threads, 2, thread_n_blocks, thread_k_blocks, stages,
                      has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
      break;
    case 3:
      MarlinMoESingle<threads, 3, thread_n_blocks, thread_k_blocks, stages,
                      has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
      break;
    default:
      MarlinMoESingle<threads, 4, thread_n_blocks, thread_k_blocks, stages,
                      has_act_order, group_blocks>(
          A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
          expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
          prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
          current_m_block);
      break;
  }
}
#else
// Placeholder for permute_cols_kernel, compiled in the `#else` branch of the
// surrounding preprocessor guard. It ignores all arguments and only trips an
// assertion.
__global__ void permute_cols_kernel(const int4* __restrict__ a_int4_ptr,
                                    const int* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr,
                                    int size_m, int size_k, int block_rows) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
// Placeholder for compute_expert_offsets, compiled in the `#else` branch of
// the surrounding preprocessor guard. It ignores all arguments and only trips
// an assertion.
__global__ void compute_expert_offsets(const int* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
// Placeholder for the MarlinMoE kernel, compiled in the `#else` branch of the
// surrounding preprocessor guard. It keeps the same template and argument list
// so host-side launch code still compiles, but only trips an assertion.
template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,    // whether act_order is enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void MarlinMoE(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int* __restrict__ sorted_ids,      // int32 sorted ids of experts
    const float* __restrict__ topk_weights,  // float topk weights
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const int* __restrict__ g_idx,        // int32 group indices of shape k
    const int* __restrict__ expert_offsets,
    int num_groups,   // number of scale groups per output channel
    int expert_idx,   // idx of current expert
    int num_experts,  // number of experts
    int topk,         // topk parameter of moe
    int prob_m,       // batch dimension m
    int prob_n,       // output dimension n
    int prob_k,       // reduction dimension k
    int tot_m,        // total number of rows in A and C
    int* locks,       // extra global storage for barrier synchronization
    bool replicate_input,  // do we use the same input for each expert?
    bool apply_weights,    // apply weights to output
    int current_m_block,   // current m block to start kernel computation from
    int max_par            // maximum parallelism
) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
#endif
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
// Note: This is only used with user-provided thread_k/n
static constexpr int USER_THREADS = 256;

// 4 pipeline stages fit into shared memory
static constexpr int STAGES = 4;

// const int SHARED_MEM =
// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

// Smallest thread-tile extents accepted by is_valid_config.
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
// Emits one `else if` branch of the MarlinMoE kernel dispatch chain.
//
// The expansion starts with `else`, so it must directly follow an `if` /
// `else if` block. The branch is taken when the runtime variables
// thread_m_blocks, thread_n_blocks, thread_k_blocks, has_act_order,
// group_blocks and num_threads (all read from the expanding scope) equal the
// macro arguments. It then
//   1. sets cudaFuncAttributeMaxDynamicSharedMemorySize of the matching
//      MarlinMoE instantiation to max_shared_mem, and
//   2. launches that instantiation with `blocks` blocks of NUM_THREADS
//      threads, max_shared_mem bytes of dynamic shared memory, on `stream`.
//
// Besides the dispatch variables, the expanding scope must also provide
// A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr,
// expert_offsets_ptr, num_groups, expert_idx, num_experts, topk, prob_m,
// prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, m_block,
// max_par, blocks, max_shared_mem and stream.
//
// NOTE: neither the cudaFuncSetAttribute return value nor the launch result
// is checked inside the macro.
#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MarlinMoE<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par); \
}
// One candidate launch shape: the k/n extent of a thread tile in the weight
// matrix and the number of threads per block used with it.
struct thread_config_t {
  int thread_k;
  int thread_n;
  int num_threads;
};

// Candidates tried by determine_thread_config when prob_m <= 16.
// Entries are ordered by priority; columns are
// {thread_k, thread_n, num_threads}.
thread_config_t small_batch_thread_configs[] = {
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
};

// Candidates tried by determine_thread_config when prob_m > 16.
// Entries are ordered by priority; columns are
// {thread_k, thread_n, num_threads}.
thread_config_t large_batch_thread_configs[] = {
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
};
// Returns true when `th_config` is usable for a problem with the given N and K
// extents.
//
// th_config: candidate {thread_k, thread_n, num_threads}. It may be a table
//            entry, the {-1, -1, -1} "nothing found" marker returned by
//            determine_thread_config, or a user-supplied pair.
// prob_m:    batch dimension (currently not used by any check).
// prob_n:    output dimension; must be divisible by thread_n.
// prob_k:    reduction dimension; must be divisible by thread_k.
//
// The range checks run before the divisibility checks so that a zero
// thread_k / thread_n (possible for user-supplied values) is rejected instead
// of being used as the right-hand side of `%`.
bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
                     int prob_k) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N. With min_thread_n > 0 this also guarantees
  // thread_n is a strictly positive divisor below.
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  return true;
}
// Picks the first valid entry from the priority table that matches the batch
// size: small_batch_thread_configs when prob_m <= 16, otherwise
// large_batch_thread_configs. Returns {-1, -1, -1} when no entry passes
// is_valid_config.
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
  const bool use_small_table = prob_m <= 16;

  const thread_config_t* candidates = use_small_table
                                          ? small_batch_thread_configs
                                          : large_batch_thread_configs;
  const int num_candidates =
      use_small_table
          ? static_cast<int>(sizeof(small_batch_thread_configs) /
                             sizeof(small_batch_thread_configs[0]))
          : static_cast<int>(sizeof(large_batch_thread_configs) /
                             sizeof(large_batch_thread_configs[0]));

  // Tables are ordered by priority, so the first match wins.
  for (int idx = 0; idx < num_candidates; ++idx) {
    if (is_valid_config(candidates[idx], prob_m, prob_n, prob_k)) {
      return candidates[idx];
    }
  }

  return thread_config_t{-1, -1, -1};
}
// Expands to the full set of __CALL_IF_MOE branches for one
// (N_BLOCKS, K_BLOCKS, NUM_THREADS) tile shape: thread_m_blocks 1..4 combined
// with
//   - has_act_order == true  and group_blocks == 0, and
//   - has_act_order == false and group_blocks in {-1, 2, 4, 8}.
// That is 20 `else if` branches in total. Because every branch begins with
// `else`, the macro must directly follow an `if` / `else if` block, and it
// may be followed by a final `else`.
#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
// Host-side driver for the Marlin MoE GEMM.
//
// Steps, all enqueued on `stream`:
//   1. pick a thread-tile configuration (user-supplied thread_k/thread_n, or
//      the first valid entry from the priority tables),
//   2. launch compute_expert_offsets to fill `expert_offsets`,
//   3. for every expert: optionally permute the columns of A into `a_tmp`
//      (act_order), then launch the matching MarlinMoE instantiation once per
//      group of 16 m-blocks.
//
// Pointer arguments are raw device pointers; their element types follow the
// casts below (int4 for A/B/C/s/a_tmp, int for sorted_ids/topk_ids/g_idx/
// perm/expert_offsets/workspace, float for topk_weights).
//
// group_size: -1 means a single scale group per output channel, 0 is only
//             valid for act_order without a full K, any other value must be a
//             positive multiple of 16.
// thread_k / thread_n / sms: -1 selects the automatic value.
//
// Raises (via TORCH_CHECK) on invalid shapes, unsupported tile or group
// combinations, and on CUDA errors reported by the attribute queries or
// kernel launches.
void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
                         const void* sorted_ids, const void* topk_weights,
                         const void* topk_ids, const void* s, const void* g_idx,
                         const void* perm, void* a_tmp, void* expert_offsets,
                         int prob_m, int prob_n, int prob_k, void* workspace,
                         bool has_act_order, bool is_k_full, int num_groups,
                         int group_size, int num_experts, int topk,
                         int moe_block_size, int dev, cudaStream_t stream,
                         int thread_k, int thread_n, int sms, int max_par,
                         bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  // Turns a failed CUDA runtime call or launch into a Torch error instead of
  // letting it surface later at an unrelated call site.
  auto check_cuda = [](cudaError_t status, const char* what) {
    TORCH_CHECK(status == cudaSuccess, what,
                " failed: ", cudaGetErrorString(status));
  };

  // Converts a group size into a count of 16-wide blocks. A size below 16
  // would yield 0 blocks (and a modulo by zero below); a non-multiple of 16
  // would silently be rounded down to a different group size.
  auto group_size_to_blocks = [&](int size) {
    TORCH_CHECK(size >= 16 && size % 16 == 0, "group_size = ", size,
                " must be a positive multiple of 16");
    int n_blocks = size / 16;
    TORCH_CHECK(prob_k % n_blocks == 0, "prob_k = ", prob_k,
                " is not divisible by group_blocks = ", n_blocks);
    return n_blocks;
  };

  if (sms == -1) {
    check_cuda(
        cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev),
        "cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount)");
  }

  // Set thread config
  thread_config_t th_config;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
  } else {
    // Auto config
    th_config = determine_thread_config(prob_m, prob_n, prob_k);
  }

  TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
              "Invalid thread config: thread_k = " + str(th_config.thread_k) +
                  ", thread_n = " + str(th_config.thread_n) +
                  ", num_threads = " + str(th_config.num_threads) +
                  " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
                  str(prob_n) + "]");

  int num_threads = th_config.num_threads;
  thread_k = th_config.thread_k;
  thread_n = th_config.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  // One thread block per SM.
  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      group_blocks = group_size_to_blocks(group_size);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size_to_blocks(group_size);
    }
  }

  int max_shared_mem = 0;
  check_cuda(
      cudaDeviceGetAttribute(&max_shared_mem,
                             cudaDevAttrMaxSharedMemoryPerBlockOptin, dev),
      "cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin)");
  TORCH_CHECK(max_shared_mem > 0);

  int tot_m = prob_m;

  const int* topk_ids_ptr = (const int*)topk_ids;
  int* expert_offsets_ptr = (int*)expert_offsets;
  // Single block with one thread per expert; an unsupported num_experts shows
  // up as a launch-configuration error in the check below.
  compute_expert_offsets<<<1, num_experts, 0, stream>>>(
      topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
  check_cuda(cudaGetLastError(), "compute_expert_offsets launch");

  bool do_permute_a = has_act_order;

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
    const int4* A_ptr = (const int4*)A;
    int4* a_tmp_ptr = (int4*)a_tmp;
    // B, s, g_idx and perm are indexed per expert; A, C, the sorted ids and
    // the topk weights are shared across experts.
    const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx;
    int4* C_ptr = (int4*)C;
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr =
        (const int4*)s +
        (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
         prob_n / 8) *
            expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;

    if (do_permute_a) {
      // Permute A columns
      int topk_rows = replicate_input ? tot_m : tot_m * topk;
      int block_rows = ceildiv(topk_rows, blocks);
      permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
          A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
      check_cuda(cudaGetLastError(), "permute_cols_kernel launch");
      A_ptr = a_tmp_ptr;
    }

    int max_m_blocks = ceildiv(tot_m, 16);
    // NOTE(review): the step of 16 m-blocks per launch is hard-coded and
    // appears to assume max_par == 4 (4 m-blocks per parallel slice) --
    // confirm before changing max_par.
    for (int m_block = 0; m_block < max_m_blocks; m_block += 16) {
      // Define kernel configurations

      // make it max possible value
      int thread_m_blocks = 4;

      // The empty `if` anchors the `else if` chain the macros expand to.
      if (false) {
      }
      CALL_IF_MOE(16, 4, 256)
      CALL_IF_MOE(8, 8, 256)
      CALL_IF_MOE(8, 4, 128)
      CALL_IF_MOE(4, 8, 128)
      else {
        TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                               str(prob_n) + ", " + str(prob_k) + "]" +
                               ", has_act_order = " + str(has_act_order) +
                               ", num_groups = " + str(num_groups) +
                               ", group_size = " + str(group_size) +
                               ", thread_m_blocks = " + str(thread_m_blocks) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
      }
      check_cuda(cudaGetLastError(), "MarlinMoE launch");
    }
  }
}
}
// namespace marlin_moe
// Torch entry point for the Marlin MoE GEMM.
//
// Allocates the output `c` of shape {size_m, topk, size_n} (same dtype and
// device as `a`), a scratch buffer for the column-permuted input (only when
// act_order is in use) and the per-expert offset table, derives num_groups /
// group_size from `b_scales` and `g_idx`, and forwards everything to
// marlin_moe::marlin_mm_moe_f16i4 on the current CUDA stream of `a`'s device.
//
// b_scales must have rank 3 with dim 2 equal to size_n; dim 1 is taken as the
// number of scale groups. act_order is considered enabled when
// g_idx.size(1) != 0.
//
// Returns `c`. Raises (via TORCH_CHECK) when b_scales has the wrong shape or
// size_k is not divisible by the number of groups.
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    const torch::Tensor& g_idx, const torch::Tensor& perm,
    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
    bool replicate_input, bool apply_weights) {
  int max_par = 4;

  int dev = a.get_device();

  // Detect act_order up front: it decides whether the permuted-input scratch
  // buffer is needed at all.
  bool has_act_order = g_idx.size(1) != 0;

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);

  // a_tmp is only read/written by the callee when act_order is on (it is the
  // destination of the column permutation), so skip the full-size allocation
  // and zero-fill otherwise.
  torch::Tensor a_tmp;
  if (has_act_order) {
    a_tmp = replicate_input
                ? torch::zeros({size_m, size_k}, options_dtype)
                : torch::zeros({size_m, topk, size_k}, options_dtype);
  } else {
    a_tmp = torch::empty({0}, options_dtype);
  }
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize
  int num_groups = -1;
  int group_size = -1;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }
  } else {
    if (num_groups > 1) {
      // The divisor is b_scales.size(1); report that dimension rather than
      // size(0).
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by b_scales.size(1) = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  marlin_moe::marlin_mm_moe_f16i4(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      has_act_order, is_k_full, num_groups, group_size, num_experts, topk,
      moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
      thread_n, sms, max_par, replicate_input, apply_weights);
  return c;
}
\ No newline at end of file
csrc/moe/marlin_moe_ops.h
deleted
100644 → 0
View file @
cde9183b
#pragma once
#include <torch/all.h>
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
\ No newline at end of file
csrc/moe/torch_bindings.cpp
View file @
aae74ef9
#include "core/registration.h"
#include "core/registration.h"
#include "moe_ops.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
// Apply topk softmax to the gating outputs.
...
@@ -8,14 +7,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -8,14 +7,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
}
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
tests/weight_loading/models.txt
View file @
aae74ef9
...
@@ -13,7 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
...
@@ -13,7 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
awq, casperhansen/mixtral-instruct-awq, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
vllm/_custom_ops.py
View file @
aae74ef9
...
@@ -300,20 +300,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
...
@@ -300,20 +300,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
return
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
,
size_k
,
size_n
,
num_bits
)
return
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
,
size_k
,
size_n
,
num_bits
)
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    """Repack the GPTQ weights of every expert into the Marlin layout.

    Each expert slice ``b_q_weight[e]`` is repacked independently with
    ``torch.ops._C.gptq_marlin_repack`` using that expert's permutation
    ``perm[e]``, and the results are stacked along dim 0.

    Args:
        b_q_weight: Packed quantized weights; dim 0 indexes the expert.
        perm: Per-expert act-order sort indices; dim 0 indexes the expert.
        size_k: Unpacked K dimension of one expert; must be a multiple of 16.
        size_n: N dimension of one expert.
        num_bits: Bit width of the quantized weights.

    Returns:
        A tensor of shape ``(num_experts, size_k // 16,
        size_n * num_bits // 2)`` on the device and with the dtype of
        ``b_q_weight``.

    Raises:
        AssertionError: if ``size_k`` is not a multiple of 16.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0

    # The per-expert repack packs 32 // num_bits values per int32 over 16-row
    # tiles, so each output row holds size_n * 16 // (32 // num_bits)
    # == size_n * num_bits // 2 elements. The previous hard-coded
    # ``size_n * 2`` was only correct for num_bits == 4.
    # NOTE(review): confirm against the gptq_marlin_repack kernel if further
    # bit widths are added.
    packed_cols = size_n * num_bits // 2
    output = torch.empty((num_experts, size_k // 16, packed_cols),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
                                                    size_k, size_n, num_bits)
    return output
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
aae74ef9
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoEMethodBase
)
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
]
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
]
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_
marlin_
moe
,
fused_
moe
,
fused_topk
,
fused_experts
,
fused_moe
,
fused_
topk
,
get_config_file_name
,
get_config_file_name
,
grouped_topk
)
grouped_topk
)
__all__
+=
[
__all__
+=
[
"fused_marlin_moe"
,
"fused_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_topk"
,
"fused_experts"
,
"fused_experts"
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
aae74ef9
...
@@ -323,16 +323,21 @@ def get_moe_configs(E: int, N: int,
...
@@ -323,16 +323,21 @@ def get_moe_configs(E: int, N: int,
return
None
return
None
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
def
get_default_config
(
dtype
:
Optional
[
str
],
M
:
int
,
is_marlin
:
bool
)
->
Dict
[
str
,
int
]:
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
)
->
Dict
[
str
,
int
]:
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
'GROUP_SIZE_M'
:
8
}
}
if
M
<=
E
or
(
is_marlin
and
M
<=
32
)
:
if
M
<=
E
:
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
32
,
'BLOCK_SIZE_N'
:
32
,
...
@@ -342,14 +347,14 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
...
@@ -342,14 +347,14 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
def
try_get_optimal_moe_config
(
w
2
_shape
:
Tuple
[
int
,
...],
w
1
_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
w2_shape
:
Tuple
[
int
,
...]
,
dtype
:
Optional
[
str
]
,
top_k
:
int
,
M
:
int
,
dtype
:
Optional
[
str
]
,
override_config
:
Optional
[
Dict
[
str
,
M
:
int
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
):
):
if
override_config
:
if
override_config
:
config
=
override_config
config
=
override_config
else
:
else
:
...
@@ -363,8 +368,7 @@ def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
...
@@ -363,8 +368,7 @@ def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
else
:
# Else use the default config
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
,
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
)
is_marlin
)
return
config
return
config
...
@@ -437,108 +441,6 @@ def grouped_topk(hidden_states: torch.Tensor,
...
@@ -437,108 +441,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
                     gating_output: torch.Tensor,
                     g_idx1: torch.Tensor,
                     g_idx2: torch.Tensor,
                     rand_perm1: torch.Tensor,
                     rand_perm2: torch.Tensor,
                     topk: int,
                     renormalize: bool,
                     override_config: Optional[Dict[str, Any]] = None,
                     use_fp8: bool = False,
                     w1_scale: Optional[torch.Tensor] = None,
                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    Marlin-packed weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The 2-D, contiguous input tensor to the
        MoE layer (tokens x hidden size).
    - w1 (torch.Tensor): The first set of packed expert weights.
    - w2 (torch.Tensor): The second set of packed expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx1 / g_idx2 (torch.Tensor): Group-index tensors forwarded to the
        kernel as ``g_idx`` for the w1 / w2 GEMM respectively.
    - rand_perm1 / rand_perm2 (torch.Tensor): Permutation tensors forwarded
        to the kernel as ``perm`` for the w1 / w2 GEMM respectively.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): Must be False; fp8 is not implemented for this path.
    - w1_scale (Optional[torch.Tensor]): Scale forwarded as ``b_scales`` for
        w1. The op schema declares it as a Tensor, so callers are expected
        to provide it even though the default is None.
    - w2_scale (Optional[torch.Tensor]): Same as ``w1_scale`` but for w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.

    Raises:
    - AssertionError: if any of the shape / contiguity / dtype checks fail
        or ``use_fp8`` is True.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[
        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    #TODO fp8 is not implemented yet
    assert not use_fp8

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16

    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                        renormalize)

    get_config_func = functools.partial(try_get_optimal_moe_config,
                                        w1.shape,
                                        w2.shape,
                                        topk_ids.shape[1],
                                        "float8" if use_fp8 else None,
                                        override_config=override_config,
                                        is_marlin=True)
    config = get_config_func(M)

    block_size_m = config['BLOCK_SIZE_M']

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    # Allocate the workspace next to the activations. A literal "cuda" device
    # resolves to the *current* CUDA device, which is not necessarily the one
    # holding hidden_states when several GPUs are in use.
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    # First GEMM (w1): size_n = 2 * N, size_k = K, is_k_full = True,
    # replicate_input = True, apply_weights = False.
    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
        block_size_m, True, False)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM (w2): size_n = K, size_k = N, is_k_full = True,
    # replicate_input = False, apply_weights = True.
    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
        block_size_m, False, True)

    # Reduce over dim 1 of the second GEMM's output. (The previous
    # ``.view(*intermediate_cache3.shape)`` was a no-op and has been dropped.)
    return torch.sum(intermediate_cache3, dim=1)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
aae74ef9
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
enum
import
Enum
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
FusedMoeWeightScaleSupported
(
Enum
):
TENSOR
=
"tensor"
CHANNEL
=
"channel"
GROUP
=
"group"
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
@
abstractmethod
@
abstractmethod
...
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
...
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
)
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_id
:
int
):
param_data
=
param
.
data
# for per tensor weight quantization
if
shard_id
in
(
"w1"
,
"w3"
):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
"w1"
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
elif
shard_id
==
"w2"
:
param_data
[
expert_id
]
=
loaded_weight
def
_load_model_weight_or_group_weight_scale
(
self
,
shard_dim
:
int
,
expert_data
:
torch
.
Tensor
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# Load grouped weight scales for group quantization
# or model weights
if
shard_id
==
"w2"
:
self
.
_load_w2
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
shard_id
in
(
"w1"
,
"w3"
):
self
.
_load_w13
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
def
_load_per_channel_weight_scale
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# for per channel weight quantization
if
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
elif
shard_id
in
(
"w1"
,
"w3"
):
self
.
_load_w13
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
def
_load_w13
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size
=
expert_data
.
shape
[
shard_dim
]
//
2
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if
shard_id
==
"w1"
:
expert_data
=
expert_data
.
narrow
(
shard_dim
,
0
,
shard_size
)
# w3, up_proj: Load into second logical weight of w13.
else
:
assert
shard_id
==
"w3"
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
expert_data
.
copy_
(
loaded_weight
)
def
_load_w2
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size
=
expert_data
.
shape
[
shard_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
# w2, down_proj: Load into only logical weight of w2.
expert_data
.
copy_
(
loaded_weight
)
def
_load_single_value
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_id
:
int
):
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
param_data
[
expert_id
]
=
loaded_weight
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
f
"got
{
shard_id
}
."
)
WEIGHT_SCALE_SUPPORTED
=
[
# Special case for fp8 scales.
e
.
value
for
e
in
FusedMoeWeightScaleSupported
if
getattr
(
param
,
"is_fp8_scale"
,
False
):
]
self
.
_load_fp8_scale
(
param
.
data
,
loaded_weight
,
weight_name
,
# Fetch the dim to shard the parameter/loaded weight
shard_id
,
expert_id
)
# based on the shard id. This will be whatever
return
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM
=
{
"w1"
:
0
,
"w2"
:
1
,
"w3"
:
0
}
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
# is_transposed: whether or not the parameter is transposed on disk
# If transposed, weight is saved as [input_dim, output_dim]
# If transposed, the loaded weight will be transposed and the dim
# Otherwise, weight is saved as [output_dim, input_dim]
# to shard the loaded weight will be flipped.
# Default is not transposed/input dim is dim 1
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
input_dim
=
getattr
(
param
,
"input_dim"
,
1
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
output_dim
=
getattr
(
param
,
"output_dim"
,
0
)
if
is_transposed
:
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
shard_dim
=
~
shard_dim
# Case weight_scales
if
"weight_scale"
in
weight_name
:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method
=
getattr
(
param
,
"quant_method"
,
None
)
if
quant_method
==
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
:
self
.
_load_per_channel_weight_scale
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
quant_method
==
FusedMoeWeightScaleSupported
.
GROUP
.
value
:
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
quant_method
==
FusedMoeWeightScaleSupported
.
TENSOR
.
value
:
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
else
:
raise
ValueError
(
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
return
if
"weight_shape"
in
weight_name
:
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case input scale
# Index the loaded weight for tp sharding.
if
"input_scale"
in
weight_name
:
# down_proj: "RowParallel" so tp sharding on input_dim
# Note: input_scale loading is only supported for fp8
if
shard_id
==
"w2"
:
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
shard_dim
=
input_dim
loaded_weight
).
abs
()
>
1e-5
:
shard_size
=
expert_data
.
shape
[
shard_dim
]
raise
ValueError
(
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
"input_scales of w1 and w3 of a layer "
elif
shard_id
in
(
"w1"
,
"w3"
):
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
shard_dim
=
output_dim
f
"vs.
{
loaded_weight
}
"
)
shard_size
=
expert_data
.
shape
[
output_dim
]
//
2
offset
=
shard_size
*
tp_rank
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
offset
,
shard_size
)
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case model weights
# Narrow parameter and load.
if
"weight"
in
weight_name
:
# w1, gate_proj: Load into first logical weight of w13.
self
.
_load_model_weight_or_group_weight_scale
(
if
shard_id
==
"w1"
:
shard_id
=
shard_id
,
expert_data
=
expert_data
.
narrow
(
shard_dim
,
0
,
shard_size
)
shard_dim
=
shard_dim
,
expert_data
.
copy_
(
loaded_weight
)
loaded_weight
=
loaded_weight
,
# w3, up_proj: Load into second logical weight of w13.
expert_data
=
expert_data
,
elif
shard_id
==
"w3"
:
tp_rank
=
tp_rank
)
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
return
expert_data
.
copy_
(
loaded_weight
)
# w2, down_proj: Load into only logical weight of w2.
elif
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
else
:
raise
ValueError
(
f
"Expected shard_id w1,w2 or w3 but got
{
shard_id
}
"
)
@
staticmethod
@
staticmethod
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
aae74ef9
...
@@ -3,12 +3,9 @@ from typing import Any, Dict, List, Optional
...
@@ -3,12 +3,9 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe
import
(
# noqa: E501
CompressedTensorsMoEMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
...
@@ -67,8 +64,6 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -67,8 +64,6 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
(
self
)
return
None
return
None
@
classmethod
@
classmethod
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
deleted
100644 → 0
View file @
cde9183b
import
enum
from
enum
import
Enum
from
typing
import
List
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
from
vllm.model_executor.utils
import
set_weight_attrs
class GPTQMarlinState(Enum):
    """State marker stored on a MoE layer as ``layer.marlin_state``."""
    # Assigned to the layer when its weights are created.
    REPACK = enum.auto()
    # NOTE(review): not assigned anywhere in this file -- presumably meant to
    # mark a layer whose weights were already repacked; confirm before use.
    READY = enum.auto()
__all__
=
[
"CompressedTensorsMoEMethod"
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    """Fused-MoE quantization method for compressed-tensors checkpoints.

    Weights are created in the pack-quantized checkpoint layout, repacked
    into the Marlin layout after loading, and executed through
    ``fused_marlin_moe``.
    """

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        """Read the weight quantization settings from ``quant_config``.

        Raises:
            AssertionError: if the weight scheme is not symmetric.
            ValueError: if the checkpoint format is not pack-quantized or
                the bit width is not in ``WNA16_SUPPORTED_BITS``.
        """
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.num_bits = config.num_bits
        # Number of quantized values stored in one int32.
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy.value
        self.group_size = config.group_size
        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        if not (self.quant_config.quant_format
                == CompressionFormat.pack_quantized.value
                and self.num_bits in WNA16_SUPPORTED_BITS):
            # Build ONE message string. Passing the pieces as separate
            # positional arguments makes the exception render as a tuple.
            raise ValueError(
                "For Fused MoE layers, only "
                f"{CompressionFormat.pack_quantized.value} "
                "is supported for the following bits: "
                f"{WNA16_SUPPORTED_BITS}")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the packed weights, scales, shapes and g_idx parameters.

        Every parameter is registered on ``layer`` and tagged with
        ``extra_weight_attrs`` (extended with ``is_transposed`` and
        ``quant_method``). For the "channel" strategy ``self.group_size`` is
        reset to -1.
        """

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })

        def register(name: str, data: torch.Tensor) -> None:
            # Wrap, register and tag one frozen parameter.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Packed weights: packed_factor values per int32 along dim 1.
        register(
            "w13_weight_packed",
            torch.empty(num_experts,
                        hidden_size // self.packed_factor,
                        2 * intermediate_size,
                        dtype=torch.int32))
        register(
            "w2_weight_packed",
            torch.empty(num_experts,
                        intermediate_size // self.packed_factor,
                        hidden_size,
                        dtype=torch.int32))

        if self.strategy == "channel":
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = intermediate_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        # Weight scales: one row per quantization group.
        register(
            "w13_weight_scale",
            torch.ones(num_experts,
                       num_groups_w13,
                       2 * intermediate_size,
                       dtype=params_dtype))
        register(
            "w2_weight_scale",
            torch.ones(num_experts,
                       num_groups_w2,
                       hidden_size,
                       dtype=params_dtype))

        # Two values per expert, loaded from the checkpoint's weight_shape.
        register("w2_weight_shape", torch.empty(num_experts, 2))
        register("w13_weight_shape", torch.empty(num_experts, 2))

        # Group indices and their sort permutations.
        register("w13_g_idx",
                 torch.empty(num_experts, hidden_size, dtype=torch.int32))
        register(
            "w2_g_idx",
            torch.empty(num_experts, intermediate_size, dtype=torch.int32))
        register("w13_g_idx_sort_indices",
                 torch.empty(num_experts, hidden_size, dtype=torch.int32))
        register(
            "w2_g_idx_sort_indices",
            torch.empty(num_experts, intermediate_size, dtype=torch.int32))

        layer.a13_scale = None
        layer.a2_scale = None
        layer.marlin_state = GPTQMarlinState.REPACK

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the loaded weights and scales into the Marlin layout.

        The g_idx tensors and their sort indices are replaced with empty
        ``(num_experts, 0)`` tensors, then the packed weights and the scales
        are repacked in place.
        """

        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        def get_scale_perms(num_bits: int):
            # Returns the column permutations for grouped scales (64 entries)
            # and single-group scales (32 entries). num_bits is not read.
            scale_perm: List[int] = []
            for i in range(8):
                scale_perm.extend([i + 8 * j for j in range(8)])
            scale_perm_single: List[int] = []
            for i in range(4):
                scale_perm_single.extend(
                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
            return scale_perm, scale_perm_single

        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                                  group_size: int, num_bits: int):
            # Permute one expert's scales; the grouped permutation applies
            # when there is more than one group along K.
            scale_perm, scale_perm_single = get_scale_perms(num_bits)
            if group_size < size_k and group_size != -1:
                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
            else:
                s = s.reshape((-1, len(scale_perm_single)))[:,
                                                            scale_perm_single]
            s = s.reshape((-1, size_n)).contiguous()
            return s

        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
                                      size_n: int, group_size: int,
                                      num_bits: int):
            # Apply marlin_permute_scales to every expert slice of s.
            num_experts = s.shape[0]
            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
                                 device=s.device,
                                 dtype=s.dtype)
            for e in range(num_experts):
                output[e] = marlin_permute_scales(s[e], size_k, size_n,
                                                  group_size, num_bits)
            return output

        # Read dim 2 of the packed weights BEFORE they are repacked below.
        size_k2 = layer.w2_weight_packed.shape[2]
        size_k13 = layer.w13_weight_packed.shape[2]

        num_experts = layer.w13_g_idx.shape[0]
        device = layer.w13_g_idx.device

        # Replace all four act-order tensors with empty placeholders.
        for name in ("w13_g_idx", "w2_g_idx", "w13_g_idx_sort_indices",
                     "w2_g_idx_sort_indices"):
            setattr(
                layer, name,
                torch.nn.Parameter(
                    torch.empty((num_experts, 0),
                                dtype=torch.int32,
                                device=device),
                    requires_grad=False,
                ))

        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)

        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            size_k13,
            layer.w13_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)
        # NOTE(review): the size_k passed here is num_groups * packed_factor
        # rather than the K dimension of w2; it only feeds the
        # ``group_size < size_k`` test above. Kept as-is -- confirm intent.
        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1] * self.packed_factor,
            size_k2,
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        """Run the MoE layer through ``fused_marlin_moe``.

        ``use_grouped_topk``, ``num_expert_group`` and ``topk_group`` are
        accepted for interface compatibility but are not forwarded.
        """

        # Imported lazily, at call time rather than at module import.
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_marlin_moe)

        return fused_marlin_moe(x,
                                layer.w13_weight_packed,
                                layer.w2_weight_packed,
                                router_logits,
                                layer.w13_g_idx,
                                layer.w2_g_idx,
                                layer.w13_g_idx_sort_indices,
                                layer.w2_g_idx_sort_indices,
                                top_k,
                                renormalize=renormalize,
                                w1_scale=layer.w13_weight_scale,
                                w2_scale=layer.w2_weight_scale)
vllm/model_executor/layers/quantization/fp8.py
View file @
aae74ef9
...
@@ -7,8 +7,7 @@ from torch.nn.parameter import Parameter
...
@@ -7,8 +7,7 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -319,16 +318,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -319,16 +318,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight_scale
,
{
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w2_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
# INPUT_SCALES
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
@@ -341,14 +343,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -341,14 +343,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_experts
,
dtype
=
torch
.
float32
),
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
else
:
else
:
layer
.
w13_input_scale
=
None
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
...
...
vllm/model_executor/model_loader/utils.py
View file @
aae74ef9
...
@@ -23,11 +23,11 @@ def get_model_architecture(
...
@@ -23,11 +23,11 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
not
in
mixtral_supported
and
model_config
.
quantization
!=
"fp8"
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/models/jamba.py
View file @
aae74ef9
...
@@ -920,7 +920,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -920,7 +920,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
loaded_weight
,
loaded_weight
,
name
,
weight_
name
,
shard_id
=
shard_id
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
expert_id
=
expert_id
)
break
break
...
...
vllm/model_executor/models/mixtral.py
View file @
aae74ef9
...
@@ -73,7 +73,6 @@ class MixtralMoE(nn.Module):
...
@@ -73,7 +73,6 @@ class MixtralMoE(nn.Module):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
# Gate always runs at half / full precision for now.
# Gate always runs at half / full precision for now.
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
num_experts
,
num_experts
,
bias
=
False
,
bias
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment