wangkx1 / text-generation-inference-dcu / Commits / 81a882ad

Commit 81a882ad, authored Nov 21, 2024 by jixx
add tgi2.4.0
parent 9822d7f6
Changes: 267
Showing 7 changed files with 0 additions and 3066 deletions (+0 / -3066):

    server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh            +0 / -77
    server/marlin/marlin_kernels/gptq_marlin_repack.cu             +0 / -350
    server/marlin/marlin_kernels/marlin_cuda_kernel.cu             +0 / -1136
    server/marlin/marlin_kernels/sparse/common/base.h              +0 / -51
    server/marlin/marlin_kernels/sparse/common/mem.h               +0 / -136
    server/marlin/marlin_kernels/sparse/common/mma.h               +0 / -191
    server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu   +0 / -1125

Too many changes to show. To preserve performance only 267 of 267+ files are displayed.
server/marlin/marlin_kernels/gptq_marlin_dtypes.cuh  (deleted, 100644 → 0; view file @ 9822d7f6)
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace gptq_marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace gptq_marlin

#endif
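The ScalarType traits above let the same kernel template be instantiated for both fp16 and bf16. As a rough, hand-written illustration (not part of the deleted file), a device helper templated on scalar_t might use only the trait members defined above (num2float, float2num):

// Hypothetical sketch, not original code: widen to fp32, apply a scale, and
// narrow back, via the ScalarType traits. Works for half and nv_bfloat16.
template <typename scalar_t>
__device__ scalar_t apply_scale(scalar_t x, scalar_t s) {
  using st = gptq_marlin::ScalarType<scalar_t>;
  float prod = st::num2float(x) * st::num2float(s);  // fp32 multiply
  return st::float2num(prod);                        // narrow back to scalar_t
}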
server/marlin/marlin_kernels/gptq_marlin_repack.cu  (deleted, 100644 → 0; view file @ 9822d7f6)
#include "gptq_marlin.cuh"
namespace
gptq_marlin
{
static
constexpr
int
repack_stages
=
8
;
static
constexpr
int
repack_threads
=
256
;
static
constexpr
int
tile_k_size
=
tile_size
;
static
constexpr
int
tile_n_size
=
tile_k_size
*
4
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
constexpr
int
pack_factor
=
32
/
num_bits
;
int
k_tiles
=
size_k
/
tile_k_size
;
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
int
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
if
(
start_k_tile
>=
k_tiles
)
{
return
;
}
int
finish_k_tile
=
min
(
start_k_tile
+
block_k_tiles
,
k_tiles
);
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
repack_stages
-
2
>
();
__syncthreads
();
};
extern
__shared__
int4
sh
[];
constexpr
int
perm_size
=
tile_k_size
/
4
;
int4
*
sh_perm_ptr
=
sh
;
int4
*
sh_pipe_ptr
=
sh_perm_ptr
;
if
constexpr
(
has_perm
)
{
sh_pipe_ptr
+=
perm_size
;
}
constexpr
int
tile_ints
=
tile_k_size
/
pack_factor
;
constexpr
int
stage_n_threads
=
tile_n_size
/
4
;
constexpr
int
stage_k_threads
=
has_perm
?
tile_k_size
:
tile_ints
;
constexpr
int
stage_size
=
stage_k_threads
*
stage_n_threads
;
auto
load_perm_to_shared
=
[
&
](
int
k_tile_id
)
{
int
first_k_int4
=
(
k_tile_id
*
tile_k_size
)
/
4
;
int4
const
*
perm_int4_ptr
=
reinterpret_cast
<
int4
const
*>
(
perm_ptr
);
if
(
threadIdx
.
x
<
perm_size
)
{
sh_perm_ptr
[
threadIdx
.
x
]
=
perm_int4_ptr
[
first_k_int4
+
threadIdx
.
x
];
}
__syncthreads
();
};
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
k_tile_id
,
int
n_tile_id
)
{
if
(
n_tile_id
>=
n_tiles
)
{
cp_async_fence
();
return
;
}
int
first_n
=
n_tile_id
*
tile_n_size
;
int4
*
sh_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
if
constexpr
(
has_perm
)
{
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
uint32_t
const
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
sh_perm_ptr
);
int
src_k
=
sh_perm_int_ptr
[
k_id
];
int
src_k_packed
=
src_k
/
pack_factor
;
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[
src_k_packed
*
size_n
+
first_n
+
(
n_id
*
4
)])));
}
}
else
{
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
int
first_k_packed
=
first_k
/
pack_factor
;
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[(
first_k_packed
+
k_id
)
*
size_n
+
first_n
+
(
n_id
*
4
)])));
}
}
cp_async_fence
();
};
auto
repack_tile
=
[
&
](
int
pipe
,
int
k_tile_id
,
int
n_tile_id
)
{
if
(
n_tile_id
>=
n_tiles
)
{
return
;
}
int
warp_id
=
threadIdx
.
x
/
32
;
int
th_id
=
threadIdx
.
x
%
32
;
if
(
warp_id
>=
4
)
{
return
;
}
int
tc_col
=
th_id
/
4
;
int
tc_row
=
(
th_id
%
4
)
*
2
;
constexpr
int
tc_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
int
cur_n
=
warp_id
*
16
+
tc_col
;
constexpr
int
sh_stride
=
64
;
constexpr
uint32_t
mask
=
(
1
<<
num_bits
)
-
1
;
int4
*
sh_stage_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
uint32_t
*
sh_stage_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_stage_ptr
);
uint32_t
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_perm_ptr
);
uint32_t
vals
[
8
];
if
constexpr
(
has_perm
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
k_idx
=
tc_row
+
tc_offsets
[
i
];
uint32_t
src_k
=
sh_perm_int_ptr
[
k_idx
];
uint32_t
src_k_pos
=
src_k
%
pack_factor
;
uint32_t
b1_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
];
uint32_t
b1_cur_val
=
(
b1_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
uint32_t
b2_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
+
8
];
uint32_t
b2_cur_val
=
(
b2_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
vals
[
i
]
=
b1_cur_val
;
vals
[
4
+
i
]
=
b2_cur_val
;
}
}
else
{
uint32_t
b1_vals
[
tile_ints
];
uint32_t
b2_vals
[
tile_ints
];
#pragma unroll
for
(
int
i
=
0
;
i
<
tile_ints
;
i
++
)
{
b1_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
sh_stride
*
i
];
b2_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
8
+
sh_stride
*
i
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
int
cur_int
=
cur_elem
/
pack_factor
;
int
cur_pos
=
cur_elem
%
pack_factor
;
vals
[
i
]
=
(
b1_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
vals
[
4
+
i
]
=
(
b2_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
}
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if
constexpr
(
num_bits
==
4
)
{
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
}
else
{
constexpr
int
pack_idx
[
4
]
=
{
0
,
2
,
1
,
3
};
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
}
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
0
]
=
res1
;
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
1
]
=
res2
;
}
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
if
constexpr
(
has_perm
)
{
load_perm_to_shared
(
k_tile_id
);
}
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
repack_tile
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
wait_for_stage
();
}
n_tile_id
+=
repack_stages
;
}
}
}
}
// namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
const
pack_factor
=
32
/
num_bits
;
// Verify B
TORCH_CHECK
((
size_k
/
pack_factor
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", pack_factor = "
,
pack_factor
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
==
size_n
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not size_n = "
,
size_n
);
// Verify device and strides
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
TORCH_CHECK
(
b_q_weight
.
dtype
()
==
at
::
kInt
,
"b_q_weight type is not kInt"
);
TORCH_CHECK
(
perm
.
device
().
is_cuda
(),
"perm is not on GPU"
);
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
TORCH_CHECK
(
perm
.
dtype
()
==
at
::
kInt
,
"perm type is not at::kInt"
);
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
b_q_weight
));
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
empty
({
size_k
/
gptq_marlin
::
tile_size
,
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
// Detect if there is act_order
bool
has_perm
=
perm
.
size
(
0
)
!=
0
;
// Get ptrs
uint32_t
const
*
b_q_weight_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
b_q_weight
.
data_ptr
());
uint32_t
const
*
perm_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
perm
.
data_ptr
());
uint32_t
*
out_ptr
=
reinterpret_cast
<
uint32_t
*>
(
out
.
data_ptr
());
// Get dev info
int
dev
=
b_q_weight
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
if
(
false
)
{
}
CALL_IF
(
4
,
false
)
CALL_IF
(
4
,
true
)
CALL_IF
(
8
,
false
)
CALL_IF
(
8
,
true
)
else
{
TORCH_CHECK
(
false
,
"Unsupported repack config: num_bits = "
,
num_bits
,
", has_perm = "
,
has_perm
);
}
return
out
;
}
#endif
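As a hedged usage sketch (not code from this repository), calling gptq_marlin_repack for a 4-bit GPTQ layer follows directly from the TORCH_CHECKs above: b_q_weight has shape (size_k / pack_factor) x size_n with pack_factor = 32 / num_bits, and perm is empty when there is no act_order permutation. The shapes below are illustrative assumptions.

// Hypothetical host-side sketch: repack a K x N = 4096 x 4096, 4-bit weight.
#include <torch/all.h>

torch::Tensor repack_example() {
  int64_t size_k = 4096, size_n = 4096, num_bits = 4;
  auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  torch::Tensor b_q_weight =
      torch::zeros({size_k / (32 / num_bits), size_n}, opts);
  torch::Tensor perm = torch::empty({0}, opts);  // no act_order
  // Output shape is (size_k / tile_size, size_n * tile_size / pack_factor).
  return gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits);
}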
server/marlin/marlin_kernels/marlin_cuda_kernel.cu  (deleted, 100644 → 0; view file @ 9822d7f6)
/*
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace marlin {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                           FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
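The cp_async helpers above are always used in the same pattern inside the kernel: issue a group of copies, close the group with cp_async_fence, and wait until at most `stages - 2` groups are still pending (wait_for_stage). A minimal stand-alone sketch of that pairing, under assumed indexing and a hypothetical STAGES parameter (not code from the deleted kernel), looks like this:

// Hypothetical sketch of the double-buffered cp.async pipeline discipline.
template <int STAGES>
__device__ void pipeline_sketch(int4* sh, const int4* gl, int iters) {
  for (int it = 0; it < iters; it++) {
    // Issue one fetch into the stage slot for this iteration (illustrative indexing).
    cp_async4(&sh[(it % STAGES) * 8 + threadIdx.x % 8],
              &gl[it * 8 + threadIdx.x % 8]);
    cp_async_fence();             // close the fetch group for this stage
    cp_async_wait<STAGES - 2>();  // keep at most STAGES - 2 groups in flight
    __syncthreads();              // the oldest stage is now safe to consume
    // ... consume the ready stage from shared memory here ...
  }
}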
template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ s,  // fp16 quantization scales of shape
                                 // (k/groupsize)xn
    int prob_m,                  // batch dimension m
    int prob_n,                  // output dimension n
    int prob_k,                  // reduction dimension k
    int* locks  // extra global storage for barrier synchronization
) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  // same size, which might involve multiple column "slices" (of width 16 *
  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  // example:
  //   0 1 3
  //   0 2 3
  //   1 2 4
  // While this kind of partitioning makes things somewhat more complicated, it
  // ensures good utilization of all SMs for many kinds of shape and GPU
  // configurations, while requiring as few slow global cross-threadblock
  // reductions as possible.

  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  // better partitioning with less reductions
  int parallel = 1;
  if (prob_m > 16 * thread_m_blocks) {
    parallel = prob_m / (16 * thread_m_blocks);
    prob_m = 16 * thread_m_blocks;
  }

  int k_tiles = prob_k / 16 / thread_k_blocks;
  int n_tiles = prob_n / 16 / thread_n_blocks;
  int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  // Ensure that the number of tiles in each stripe is a multiple of the
  // groupsize; this avoids an annoying special case where a stripe starts in
  // the middle of group.
  if (group_blocks != -1)
    iters = (group_blocks / thread_k_blocks) *
            ceildiv(iters, (group_blocks / thread_k_blocks));

  int slice_row = (iters * blockIdx.x) % k_tiles;
  int slice_col_par = (iters * blockIdx.x) / k_tiles;
  int slice_col = slice_col_par;
  int slice_iters;  // number of threadblock tiles in the current slice
  int slice_count =
      0;          // total number of active threadblocks in the current slice
  int slice_idx;  // index of threadblock in current slice; numbered bottom to
                  // top

  // We can easily implement parallel problem execution by just remapping
  // indices and advancing global pointers
  if (slice_col_par >= n_tiles) {
    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
    locks += (slice_col_par / n_tiles) * n_tiles;
    slice_col = slice_col_par % n_tiles;
  }

  // Compute all information about the current slice which is required for
  // synchronization.
  auto init_slice = [&]() {
    slice_iters =
        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
    if (slice_iters == 0) return;
    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = ceildiv(k_tiles - col_off, iters);
      if (col_off > 0) slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0) slice_idx--;
      }
    }
    if (slice_col == n_tiles) {
      A += 16 * thread_m_blocks * prob_k / 8;
      C += 16 * thread_m_blocks * prob_n / 8;
      locks += n_tiles;
      slice_col = 0;
    }
  };
  init_slice();

  int a_gl_stride = prob_k / 8;  // stride of the A matrix in global memory
  // We typically use `constexpr` to indicate that this value is a compile-time
  // constant
  constexpr int a_sh_stride =
      16 * thread_k_blocks / 8;  // stride of an A matrix tile in shared memory
  constexpr int a_gl_rd_delta_o =
      16 * thread_k_blocks /
      8;  // delta between subsequent A tiles in global memory
  int a_gl_rd_delta_i =
      a_gl_stride *
      (threads / a_gl_rd_delta_o);  // between subsequent accesses within a tile
  constexpr int a_sh_wr_delta =
      a_sh_stride * (threads / a_gl_rd_delta_o);  // between shared memory writes
  constexpr int a_sh_rd_delta_o =
      2 * ((threads / 32) /
           (thread_n_blocks / 4));  // between shared memory tile reads
  constexpr int a_sh_rd_delta_i =
      a_sh_stride * 16;  // within a shared memory tile
  constexpr int a_sh_stage =
      a_sh_stride * (16 * thread_m_blocks);  // overall size of a tile
  constexpr int a_sh_wr_iters =
      ceildiv(a_sh_stage,
              a_sh_wr_delta);  // number of shared write iterations for a tile

  int b_gl_stride = 16 * prob_n / 32;
  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  constexpr int b_sh_wr_delta = threads;
  constexpr int b_sh_rd_delta = threads;
  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

  int s_gl_stride = prob_n / 8;
  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  constexpr int s_sh_stage = s_sh_stride;
  int s_gl_rd_delta = s_gl_stride;

  // Global A read index of current thread.
  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  a_gl_rd += a_gl_rd_delta_o * slice_row;
  // Shared write index of current thread.
  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  // Shared read index.
  int a_sh_rd =
      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));

  int b_gl_rd =
      b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  b_gl_rd += b_sh_stride * slice_col;
  b_gl_rd += b_gl_rd_delta_o * slice_row;
  int b_sh_wr = threadIdx.x;
  int b_sh_rd = threadIdx.x;

  int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
                s_sh_stride * slice_col + threadIdx.x;
  int s_sh_wr = threadIdx.x;
  int s_sh_rd;
  // We use a different scale layout for grouped and column-wise quantization as
  // we scale a `half2` tile in column-major layout in the former and in
  // row-major in the latter case.
  if (group_blocks != -1)
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
              (threadIdx.x % 32) / 4;
  else
    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
              (threadIdx.x % 32) % 4;

  // Precompute which thread should not read memory in which iterations; this is
  // needed if there are more threads than required for a certain tilesize or
  // when the batchsize is not a multiple of 16.
  bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

  // To ensure that writing and reading A tiles to/from shared memory, the
  // latter in fragment format, is fully bank conflict free, we need to use a
  // rather fancy XOR-based layout. The key here is that neither reads nor
  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  // same shared memory banks. Further, it seems (based on NSight-Compute) that
  // each warp must also write a consecutive memory segment?
  auto transform_a = [&](int i) {
    int row = i / a_gl_rd_delta_o;
    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  };
  // Since the computation of this remapping is non-trivial and, due to our main
  // loop unrolls, all shared memory accesses are static, we simply precompute
  // both transformed reads and writes.
  int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
    for (int j = 0; j < thread_m_blocks; j++)
      a_sh_rd_trans[i][j] =
          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  }

  // Since B-accesses have non-constant stride they have to be computed at
  // runtime; we break dependencies between subsequent accesses with a tile by
  // maintining multiple pointers (we have enough registers), a tiny
  // optimization.
  const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++)
    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_s = sh_b + (stages * b_sh_stage);
  // Register storage for double buffer of shared memory reads.
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];

  // Zero accumulators.
  auto zero_accums = [&]() {
#pragma unroll
    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
      reinterpret_cast<float*>(frag_c)[i] = 0;
  };

  // Asynchronously fetch the next A, B and s tile from global to the next
  // shared memory pipeline location.
  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
    if (pred) {
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
      for (int i = 0; i < a_sh_wr_iters; i++) {
        cp_async4_pred(
            &sh_a_stage[a_sh_wr_trans[i]],
            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
            a_sh_wr_pred[i]);
      }
      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
      for (int i = 0; i < b_sh_wr_iters; i++) {
        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
        B_ptr[i] += b_gl_rd_delta_o;
      }
      // Only fetch scales if this tile starts a new group
      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
        if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
        s_gl_rd += s_gl_rd_delta;
      }
    }
    // Insert a fence even when we are winding down the pipeline to ensure that
    // waiting is also correct at this point.
    cp_async_fence();
  };

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<stages - 2>();
    __syncthreads();
  };

  // Load the next sub-tile from the current location in the shared memory pipe
  // into the current register buffer.
  auto fetch_to_registers = [&](int k, int pipe) {
    // It may seem inefficient that we reload the groups for every sub-tile;
    // however, this does not seem to be a significant bottleneck, while some
    // theoretically better attempts have lead to bad instruction ordering by
    // the compiler and correspondingly a noticeable drop in performance.
    if (group_blocks != -1) {
      int4* sh_s_stage =
          sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                               (pipe / (group_blocks / thread_k_blocks)));
      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
    }
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
    for (int i = 0; i < thread_m_blocks; i++)
      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
        &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  };

  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
    for (int j = 0; j < 4; j++) {
      int b_quant = frag_b_quant[k % 2][j];
      int b_quant_shift = b_quant >> 8;
      FragB frag_b0 = dequant(b_quant);
      // If there are no groups, we can just scale the final output once and can
      // avoid doing so for each weight.
      if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
      FragB frag_b1 = dequant(b_quant_shift);
      if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
#pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  };

  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we have
  // multiple warps that accumulate their partial sums of the same output
  // location; which we have to reduce over in the end. We do in shared memory.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride;
      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
                      (threadIdx.x % b_sh_stride);

      // Parallel logarithmic shared memory reduction. We make sure to avoid any
      // unnecessary read or write iterations, e.g., for two warps we write only
      // once by warp 1 and read only once by warp 0.
#pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr =
                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                float* c_rd =
                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] =
                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
#pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd =
                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                  c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };

  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  auto global_reduce = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 4 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;

      int row = (threadIdx.x % 32) / 4;

      if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
                         &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                            c_gl_wr_delta_i * (i % 2)],
                         i < (thread_m_blocks - 1) * 4 ||
                             8 * (i / 2) + row < prob_m);
        }
        cp_async_fence();
        cp_async_wait<0>();
      }

#pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<float*>(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                  __half2float(reinterpret_cast<__half*>(&c_red)[j]);
            }
          }
          if (!last) {
            int4 c;
#pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<__half*>(&c)[j] =
                  __float2half(reinterpret_cast<float*>(
                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
            }
            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
                c;
          }
        }
      }
    }
  };

  // Write out the reduce final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
  auto write_result = [&]() {
    int c_gl_stride = prob_n / 8;
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int c_sh_rd_delta =
        c_sh_stride * (threads / (2 * thread_n_blocks));

    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;
    int c_sh_wr =
        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
    c_sh_wr += 32 * (threadIdx.x / 32);
    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));

    int c_gl_wr_end = c_gl_stride * prob_m;

    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    auto write = [&](int idx, float c0, float c1, FragS& s) {
      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
      if (group_blocks ==
          -1)  // for per-column quantization we finally apply the scale here
        res = __hmul2(res, s[0]);
      ((half2*)sh)[idx] = res;
    };
    if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = c_sh_wr + 8 * j;
          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
        }
        c_sh_wr += 16 * (4 * c_sh_stride);
      }
    }
    __syncthreads();

#pragma unroll
    for (int i = 0;
         i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (c_gl_wr < c_gl_wr_end) {
        C[c_gl_wr] = sh[c_sh_rd];
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };

  // Start global fetch and register load pipelines.
  auto start_pipes = [&]() {
#pragma unroll
    for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
    zero_accums();
    wait_for_stage();
    fetch_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  };
  start_pipes();

  // Main loop.
  while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines have
// even length meaning that the next iteration will always start at index 0.
#pragma unroll
    for (int pipe = 0; pipe < stages;) {
#pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe,
                          slice_iters >= stages);
          pipe++;
          wait_for_stage();
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) break;
    }
    a_gl_rd += a_gl_rd_delta_o * stages;

    // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to noticeably worse performance after compilation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step before
      // write-out
      if (group_blocks == -1 && last) {
        if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
        cp_async_fence();
      }
      thread_block_reduce();
      if (group_blocks == -1 && last) {
        cp_async_wait<0>();
        __syncthreads();
        if (threadIdx.x / 32 < thread_n_blocks / 4) {
          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
        }
      }
      if (slice_count > 1) {  // only globally reduce if there is more than one
                              // block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        global_reduce(slice_idx == 0, last);
        barrier_release(&locks[slice_col], last);
      }
      if (last)  // only the last block in a slice actually writes the result
        write_result();
      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();
      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
#pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }
        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
        start_pipes();
      }
    }
  }
}
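For orientation, the per-stage shared-memory footprint follows directly from the strides above. The host-side sketch below is not code from the file; it assumes the large-batch default tile (thread_k = 64, thread_n = 256, 256 threads, thread_m_blocks = 4, 4 stages) and just evaluates the same formulas the kernel uses, in int4 (16-byte) units.

// Hedged arithmetic sketch: pipeline footprint for one illustrative config.
#include <cstdio>

int main() {
  const int thread_m_blocks = 4, thread_n_blocks = 16, thread_k_blocks = 4,
            stages = 4;
  int a_sh_stride = 16 * thread_k_blocks / 8;              // 8 int4
  int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);   // 512 int4 for A
  int b_sh_stride = 32 * thread_n_blocks / 4;              // 128 int4
  int b_sh_stage = b_sh_stride * thread_k_blocks;          // 512 int4 for B
  int s_sh_stage = 16 * thread_n_blocks / 8;               // 32 int4 for scales
  int total_bytes = stages * (a_sh_stage + b_sh_stage + s_sh_stage) * 16;
  // ~66 KiB, which fits under the 96 KiB SHARED_MEM budget defined below.
  std::printf("approx. pipeline footprint: %d bytes\n", total_bytes);
  return 0;
}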
#else

template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ s,  // fp16 quantization scales of shape
                                 // (k/groupsize)xn
    int prob_m,                  // batch dimension m
    int prob_n,                  // output dimension n
    int prob_k,                  // reduction dimension k
    int* locks  // extra global storage for barrier synchronization
) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

#endif

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS =
    256;               // Note: This is only used with user-provided thread_k/n
const int STAGES = 4;  // 4 pipeline stages fit into shared memory
const int SHARED_MEM =
    96 * 1024;  // max shared memory on compute capability 8.6 (< 8.0)

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

static constexpr int pack_factor_4bit =
    8;  // We have 8 4-bit vals inside a 32 bit

#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,           \
                  GROUP_BLOCKS, NUM_THREADS)                                   \
  else if (thread_m_blocks == THREAD_M_BLOCKS &&                               \
           thread_n_blocks == THREAD_N_BLOCKS &&                               \
           thread_k_blocks == THREAD_K_BLOCKS &&                               \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
    cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
                                THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,        \
                         cudaFuncAttributeMaxDynamicSharedMemorySize,          \
                         SHARED_MEM);                                          \
    Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
           STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
        A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks);            \
  }

typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
};

bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
                     int prob_k) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  return true;
}

thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
  if (prob_m <= 16) {
    for (auto th_config : small_batch_thread_configs) {
      if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
        return th_config;
      }
    }
  } else {
    for (auto th_config : large_batch_thread_configs) {
      if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
        return th_config;
      }
    }
  }

  return thread_config_t{-1, -1, -1};
}

#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
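As an aside (not part of the deleted file), the auto-config path above is easy to trace by hand. For a single decode token, prob_m = 1 falls into the small-batch table, which is searched in priority order, so a 4096 x 4096 layer selects the first entry because 4096 is divisible by both 128s:

// Hedged example: what determine_thread_config returns for prob_m = 1,
// prob_n = 4096, prob_k = 4096 (the first small-batch entry passes
// is_valid_config).
void config_example() {
  thread_config_t cfg = determine_thread_config(/*prob_m=*/1,
                                                /*prob_n=*/4096,
                                                /*prob_k=*/4096);
  // cfg.thread_k == 128, cfg.thread_n == 128, cfg.num_threads == 256
  (void)cfg;
}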
void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
                 int prob_n, int prob_k, void* workspace, int groupsize = -1,
                 int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
                 int thread_n = -1, int sms = -1, int max_par = 16) {
  int tot_m = prob_m;
  int tot_m_blocks = ceildiv(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1)
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);

  // Set thread config
  thread_config_t th_config;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
  } else {
    // Auto config
    th_config = determine_thread_config(prob_m, prob_n, prob_k);
  }

  if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
    throw std::runtime_error(
        "Invalid thread config: thread_k = " + str(th_config.thread_k) +
        ", thread_n = " + str(th_config.thread_n) +
        ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
        str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
  }

  // Uncomment for debug
  // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
  //                  ", thread_n = " + str(th_config.thread_n) +
  //                  ", num_threads = " + str(th_config.num_threads) + " for
  //                  MKN = [" + str(prob_m) +
  //                  ", " + str(prob_k) + ", " + str(prob_n) + "]\n";

  int num_threads = th_config.num_threads;
  thread_k = th_config.thread_k;
  thread_n = th_config.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  int blocks = sms;

  if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
    return;
  }

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);
  if (group_blocks != -1) {
    TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                " is not divisible by group_blocks = ", group_blocks);
  }

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  const int4* s_ptr = (const int4*)s;

  int* locks = (int*)workspace;

  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / 64;
      if (par > max_par) par = max_par;
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }

    // For compilation speed, we only define the kernel configurations that have
    // seemed useful (in terms of performance) in our testing, however many more
    // are, in principle, possible.
    if (false) {
    }
    CALL_IF(8, 8, 256)
    CALL_IF(16, 4, 256)
    CALL_IF(8, 4, 128)
    CALL_IF(4, 8, 128)
    else {
      throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                               ", " + str(prob_k) + ", " + str(prob_n) + "]" +
                               ", groupsize = " + str(groupsize) +
                               ", thread_m_blocks = " + str(thread_m_blocks) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
    }

    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
}

}  // namespace marlin
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k) {
  // Verify M
  TORCH_CHECK(size_m == a.size(0),
              "Shape mismatch: a.size(0) = " + str(a.size(0)) +
                  ", size_m = " + str(size_m));

  // Verify K
  TORCH_CHECK(size_k == a.size(1),
              "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                  ", size_k = " + str(size_k));
  TORCH_CHECK(size_k % marlin::tile_size == 0,
              "size_k = " + str(size_k) +
                  " is not divisible by tile_size = " + str(marlin::tile_size));
  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = " +
                  str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
                  ", tile_size = " + str(marlin::tile_size));

  // Verify N
  TORCH_CHECK(b_scales.size(1) == size_n,
              "b_scales.size(1) = " + str(b_scales.size(1)) +
                  ", size_n = " + str(size_n));
  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
              "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
                  " is not divisible by tile_size = " + str(marlin::tile_size));

  int actual_size_n =
      (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
  TORCH_CHECK(size_n == actual_size_n,
              "size_n = " + str(size_n) +
                  ", actual_size_n = " + str(actual_size_n));

  // Verify A device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  // Verify B device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  // Verify scales device and strides
  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  // Alloc C matrix
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize
  if (b_scales.size(0) != 1) {
    TORCH_CHECK(size_k % b_scales.size(0) == 0,
                "size_k = " + str(size_k) +
                    ", is not divisible by b_scales.size(0) = " +
                    str(b_scales.size(0)));
  }
  int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);

  // Verify groupsize
  TORCH_CHECK(groupsize == -1 || groupsize == 128,
              "Unexpected groupsize = " + str(groupsize));

  // Verify workspace size
  TORCH_CHECK(
      size_n % marlin::min_thread_n == 0,
      "size_n = " + str(size_n) +
          ", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = " + str(workspace.numel()) +
                  " is below min_workspace_size = " + str(min_workspace_size));

  int dev = a.get_device();
  marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
                      b_scales.data_ptr(), size_m, size_n, size_k,
                      workspace.data_ptr(), groupsize, dev,
                      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
                      sms, marlin::max_par);

  return c;
}
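A hedged usage sketch (not code from this repository): invoking marlin_gemm for a single-token matmul against a 4-bit, groupsize-128 Marlin weight. The shapes below are assumptions that simply satisfy the TORCH_CHECKs above; b_q_weight is (K / 16) x (N * 16 / 8) int32, b_scales is (K / 128) x N fp16, and the workspace needs at least (N / min_thread_n) * max_par entries.

// Hypothetical host-side example with K = N = 4096, M = 1.
#include <torch/all.h>

torch::Tensor gemm_example() {
  int64_t m = 1, n = 4096, k = 4096;
  auto f16 = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
  auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  torch::Tensor a = torch::randn({m, k}, f16);
  torch::Tensor b_q_weight = torch::zeros({k / 16, n * 16 / 8}, i32);
  torch::Tensor b_scales = torch::ones({k / 128, n}, f16);
  // min_thread_n = 64, max_par = 16  =>  at least (n / 64) * 16 entries.
  torch::Tensor workspace = torch::zeros({(n / 64) * 16}, i32);
  return marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k);
}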
server/marlin/marlin_kernels/sparse/common/base.h  (deleted, 100644 → 0; view file @ 9822d7f6)
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace marlin_24 {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_>
struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales

}  // namespace marlin_24
server/marlin/marlin_kernels/sparse/common/mem.h
deleted
100644 → 0
View file @
9822d7f6
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
namespace
marlin_24
{
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__
inline
void
cp_async4_pred_zfill
(
void
*
smem_ptr
,
const
void
*
glob_ptr
,
bool
pred
=
true
,
const
bool
zfill
=
false
)
{
const
int
BYTES
=
16
;
int
src_in_bytes
=
(
zfill
?
0
:
BYTES
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" .reg .pred p;
\n
"
" setp.ne.b32 p, %0, 0;
\n
"
" @p cp.async.cg.shared.global [%1], [%2], %3;
\n
"
"}
\n
"
::
"r"
((
int
)
pred
),
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
),
"r"
(
src_in_bytes
));
}
__device__
inline
void
cp_async4_pred
(
void
*
smem_ptr
,
const
void
*
glob_ptr
,
bool
pred
=
true
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" .reg .pred p;
\n
"
" setp.ne.b32 p, %0, 0;
\n
"
" @p cp.async.cg.shared.global [%1], [%2], %3;
\n
"
"}
\n
"
::
"r"
((
int
)
pred
),
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
// Asynchronous global->shared copy
__device__
inline
void
cp_async4
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" cp.async.cg.shared.global [%0], [%1], %2;
\n
"
"}
\n
"
::
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
// Async copy fence.
__device__
inline
void
cp_async_fence
()
{
asm
volatile
(
"cp.async.commit_group;
\n
"
::
);
}
// Wait until at most `n` async copy stages are still pending.
template
<
int
n
>
__device__
inline
void
cp_async_wait
()
{
asm
volatile
(
"cp.async.wait_group %0;
\n
"
::
"n"
(
n
));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__
inline
void
ldsm4
(
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
__device__
inline
void
ldsm4_m
(
FragM
&
frag_m
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_m
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
])
:
"r"
(
smem
));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__
inline
void
ldsm4_t
(
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
}
// namespace marlin_24
server/marlin/marlin_kernels/sparse/common/mma.h
deleted
100644 → 0
View file @
9822d7f6
/*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "base.h"
#include <cudaTypedefs.h>
namespace
marlin_24
{
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
// | reduced performance on some future architectures
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
#define MMA_SP_INST \
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#else
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
#endif
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__
inline
void
mma_sp
(
const
FragB
&
a_frag0
,
const
FragB
&
a_frag1
,
const
FragA
&
frag_b
,
FragC
&
frag_c
,
FragM
&
frag_m
,
const
int
psel
)
{
const
uint32_t
*
a0
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag0
);
const
uint32_t
*
a1
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag1
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
const
uint32_t
*
e
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_m
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
if
(
psel
==
0
)
{
asm
volatile
(
MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a0
[
0
]),
"r"
(
a1
[
0
]),
"r"
(
a0
[
1
]),
"r"
(
a1
[
1
]),
"r"
(
b
[
0
]),
"r"
(
b
[
2
]),
"r"
(
b
[
4
]),
"r"
(
b
[
6
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]),
"r"
(
e
[
0
]));
asm
volatile
(
MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;
\n
"
:
"=f"
(
c
[
4
]),
"=f"
(
c
[
5
]),
"=f"
(
c
[
6
]),
"=f"
(
c
[
7
])
:
"r"
(
a0
[
0
]),
"r"
(
a1
[
0
]),
"r"
(
a0
[
1
]),
"r"
(
a1
[
1
]),
"r"
(
b
[
1
]),
"r"
(
b
[
3
]),
"r"
(
b
[
5
]),
"r"
(
b
[
7
]),
"f"
(
c
[
4
]),
"f"
(
c
[
5
]),
"f"
(
c
[
6
]),
"f"
(
c
[
7
]),
"r"
(
e
[
0
]));
}
else
{
asm
volatile
(
MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a0
[
0
]),
"r"
(
a1
[
0
]),
"r"
(
a0
[
1
]),
"r"
(
a1
[
1
]),
"r"
(
b
[
0
]),
"r"
(
b
[
2
]),
"r"
(
b
[
4
]),
"r"
(
b
[
6
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]),
"r"
(
e
[
0
]));
asm
volatile
(
MMA_SP_INST
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;
\n
"
:
"=f"
(
c
[
4
]),
"=f"
(
c
[
5
]),
"=f"
(
c
[
6
]),
"=f"
(
c
[
7
])
:
"r"
(
a0
[
0
]),
"r"
(
a1
[
0
]),
"r"
(
a0
[
1
]),
"r"
(
a1
[
1
]),
"r"
(
b
[
1
]),
"r"
(
b
[
3
]),
"r"
(
b
[
5
]),
"r"
(
b
[
7
]),
"f"
(
c
[
4
]),
"f"
(
c
[
5
]),
"f"
(
c
[
6
]),
"f"
(
c
[
7
]),
"r"
(
e
[
0
]));
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
__device__
__forceinline__
uint2
to_half4
(
float
c0
,
float
c1
,
float
c2
,
float
c3
)
{
uint2
r
;
asm
(
"{
\n\t
"
".reg .f16 a, b, c, d;
\n\t
"
"cvt.rn.f16.f32 a, %2;
\n\t
"
"cvt.rn.f16.f32 b, %3;
\n\t
"
"cvt.rn.f16.f32 c, %4;
\n\t
"
"cvt.rn.f16.f32 d, %5;
\n\t
"
"mov.b32 %0, {a, b};
\n\t
"
"mov.b32 %1, {c, d};
\n\t
"
"}"
:
"=r"
(
r
.
x
),
"=r"
(
r
.
y
)
:
"f"
(
c0
),
"f"
(
c1
),
"f"
(
c2
),
"f"
(
c3
));
return
r
;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__
inline
FragB
dequant_4bit
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__
inline
FragB
dequant_8bit
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__
inline
void
scale
(
FragB
&
frag_b
,
FragS
&
frag_s
,
int
i
)
{
half2
s
=
__half2half2
(
reinterpret_cast
<
__half
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
__device__
inline
void
scale_floats
(
float
*
c0
,
float
*
c1
,
float
*
c2
,
float
*
c3
,
FragS
&
s0
,
float
*
c4
,
float
*
c5
,
float
*
c6
,
float
*
c7
,
FragS
&
s1
)
{
*
c0
=
__fmul_rn
(
*
c0
,
__half2float
(
s0
[
0
].
x
));
*
c1
=
__fmul_rn
(
*
c1
,
__half2float
(
s0
[
0
].
y
));
*
c2
=
__fmul_rn
(
*
c2
,
__half2float
(
s0
[
1
].
x
));
*
c3
=
__fmul_rn
(
*
c3
,
__half2float
(
s0
[
1
].
y
));
*
c4
=
__fmul_rn
(
*
c4
,
__half2float
(
s1
[
0
].
x
));
*
c5
=
__fmul_rn
(
*
c5
,
__half2float
(
s1
[
0
].
y
));
*
c6
=
__fmul_rn
(
*
c6
,
__half2float
(
s1
[
1
].
x
));
*
c7
=
__fmul_rn
(
*
c7
,
__half2float
(
s1
[
1
].
y
));
}
}
// namespace marlin_24
server/marlin/marlin_kernels/sparse/marlin_24_cuda_kernel.cu
deleted
100644 → 0
View file @
9822d7f6
/*
* Notice: This file was modified by Neuralmagic inc to include 8-bit support
*
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
* Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "common/base.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#else
#include "common/mem.h"
#include "common/mma.h"
#endif
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
marlin_24
{
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static
constexpr
int
THREADS
=
256
;
static
constexpr
int
STAGES
=
4
;
static
constexpr
int
min_thread_n
=
128
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
max_par
=
64
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
const
int
num_bits
,
// weight bits
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin_24
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
meta
,
// 2bit metadata information about 2:4
// format on B
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
s
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{}
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
template
<
const
int
num_bits
,
// weight bits
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__
void
Marlin_24
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
meta
,
// 2bit metadata information about 2:4
// format on B
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
s
,
// fp16 quantization scales of shape
// (k/groupsize)xn
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
// number of thread_k_blocks in k-dim
int
k_tiles
=
prob_k
/
32
/
thread_k_blocks
;
// number of thread_n_blocks in n-dim
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
// iters needed to cover all slices
int
iters
=
ceildiv
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts in
// the middle of group.
if
(
group_blocks
!=
-
1
)
iters
=
(
group_blocks
/
thread_k_blocks
)
*
ceildiv
(
iters
,
(
group_blocks
/
thread_k_blocks
));
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
// number of threadblock tiles in the current slice
int
slice_iters
;
// total number of active threadblocks in the current slice
int
slice_count
=
0
;
// index of threadblock in current slice; numbered bottom to top
int
slice_idx
;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
ceildiv
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
ceildiv
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
int
a_gl_stride
=
prob_k
/
8
;
// stride of the A matrix in global memory
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
32
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
32
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads //RLC: 2 * #warps k-dim
constexpr
int
a_sh_rd_delta_o
=
4
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
ceildiv
(
a_sh_stage
,
a_sh_wr_delta
);
constexpr
int
pack_factor
=
32
/
num_bits
;
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
int
m_gl_stride
=
2
*
prob_n
/
8
;
// (16*2*4 / 8) = 16
constexpr
int
m_sh_stride
=
(
16
*
thread_n_blocks
)
/
4
;
// #warps n-dim * threads/warp
int
m_gl_rd_delta_o
=
m_gl_stride
*
thread_k_blocks
;
int
m_gl_rd_delta_i
=
m_gl_stride
*
(
threads
/
m_sh_stride
);
constexpr
int
m_sh_wr_delta
=
threads
/
2
;
constexpr
int
m_sh_rd_delta
=
threads
/
2
;
constexpr
int
m_sh_stage
=
m_sh_stride
*
thread_k_blocks
;
constexpr
int
m_sh_iters
=
ceildiv
(
m_sh_stage
,
m_sh_wr_delta
);
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_sh_stage
=
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
4
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
int
m_gl_rd
=
m_gl_stride
*
(
threadIdx
.
x
/
(
m_sh_stride
))
+
(
threadIdx
.
x
%
(
m_sh_stride
));
m_gl_rd
+=
(
m_sh_stride
)
*
slice_col
;
m_gl_rd
+=
m_gl_rd_delta_o
*
slice_row
;
int
m_sh_wr
=
threadIdx
.
x
;
int
m_sh_rd
=
threadIdx
.
x
%
16
+
(
threadIdx
.
x
/
32
)
*
16
;
int
s_gl_rd
;
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
int
s_sh_wr
=
threadIdx
.
x
;
int
s_sh_rd
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
if
(
group_blocks
!=
-
1
)
{
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
}
else
{
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
}
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
2
][
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
{
a_sh_rd_trans
[
0
][
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
a_sh_rd_trans
[
1
][
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
+
2
);
}
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
bool
m_sh_wr_pred
=
threadIdx
.
x
<
m_sh_wr_delta
;
const
int4
*
meta_ptr
[
m_sh_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
m_sh_iters
;
i
++
)
meta_ptr
[
i
]
=
meta
+
m_gl_rd_delta_i
*
i
+
m_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_s
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_m
=
sh_s
+
(
stages
*
s_sh_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
][
2
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragM
frag_m
[
2
][
2
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
int4
*
sh_meta_stage
=
sh_m
+
m_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
m_sh_iters
;
i
++
)
{
if
(
m_sh_wr_pred
)
cp_async4
(
&
sh_meta_stage
[
m_sh_wr_delta
*
i
+
m_sh_wr
],
meta_ptr
[
i
]);
meta_ptr
[
i
]
+=
m_gl_rd_delta_o
;
}
// Only fetch scales if this tile starts a new group
if
(
group_blocks
!=
-
1
&&
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
s_gl_rd
+=
s_gl_rd_delta
;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
// It may seem inefficient that we reload the groups for every sub-tile;
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
ldsm4
(
frag_a
[
k
%
2
][
i
][
0
],
&
sh_a_stage
[
a_sh_rd_trans
[
0
][
k
%
b_sh_wr_iters
][
i
]]);
ldsm4
(
frag_a
[
k
%
2
][
i
][
1
],
&
sh_a_stage
[
a_sh_rd_trans
[
1
][
k
%
b_sh_wr_iters
][
i
]]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
// Load meta with ldsm4
int4
*
sh_m_stage
=
sh_m
+
m_sh_stage
*
pipe
;
ldsm4_m
(
frag_m
[
k
%
2
][
0
],
&
sh_m_stage
[
m_sh_rd_delta
*
(
k
%
m_sh_iters
)
+
m_sh_rd
]);
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
if
constexpr
(
num_bits
==
4
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
frag_b0
=
dequant_4bit
(
b_quant
);
frag_b1
=
dequant_4bit
(
b_quant_shift
);
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
(
b_quant_0
);
frag_b1
=
dequant_8bit
(
b_quant_1
);
}
// If there are no groups, we can just scale the final output once and can
// avoid doing so for each weight.
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma_sp
(
frag_b0
,
frag_b1
,
frag_a
[
k
%
2
][
i
][
0
],
frag_c
[
i
][
j
][
0
],
frag_m
[
k
%
2
][
j
/
2
],
j
%
2
);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
2
*
4
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
c_gl_stride
;
// 8 threads (e.g., 0,4,8,12,16,20,24,28)
int
c_gl_wr
=
2
*
c_gl_stride
*
(
threadIdx
.
x
%
4
)
+
8
*
(
threadIdx
.
x
/
32
)
+
(
threadIdx
.
x
%
32
)
/
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
col
=
2
*
((
threadIdx
.
x
%
32
)
%
4
);
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
col
+
(
i
%
2
)
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
col
+
(
i
%
2
)
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j2
=
0
;
j2
<
2
;
j2
++
)
{
#pragma unroll
for
(
int
j1
=
0
;
j1
<
4
;
j1
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
8
*
j1
+
2
*
j2
+
4
*
((
i
%
4
)
/
2
)
+
i
%
2
]
+=
__half2float
(
reinterpret_cast
<
__half
*>
(
&
c_red
)[(
j2
*
4
+
j1
)]);
}
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j2
=
0
;
j2
<
2
;
j2
++
)
{
#pragma unroll
for
(
int
j1
=
0
;
j1
<
4
;
j1
++
)
{
reinterpret_cast
<
__half
*>
(
&
c
)[(
j2
*
4
+
j1
)]
=
__float2half
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
8
*
j1
+
2
*
j2
+
4
*
((
i
%
4
)
/
2
)
+
i
%
2
]);
}
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
;
// RLC:
constexpr
int
c_sh_stride_2
=
2
*
c_sh_stride
+
2
;
// RLC:
constexpr
int
c_sh_stride_3
=
2
*
(
2
*
thread_n_blocks
)
+
2
;
// RLC:
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
c_sh_stride_2
*
((
threadIdx
.
x
%
32
)
%
4
)
+
((
threadIdx
.
x
%
32
)
/
4
);
// RLC:
c_sh_wr
+=
8
*
(
threadIdx
.
x
/
32
);
// 128/4(half4)
constexpr
int
c_sh_rd_delta
=
c_sh_stride_3
*
(
threads
/
(
2
*
2
*
thread_n_blocks
));
// RLC:
int
c_sh_rd
=
c_sh_stride_3
*
(
threadIdx
.
x
/
(
2
*
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
float
c2
,
float
c3
,
FragS
&
s0
,
float
c4
,
float
c5
,
float
c6
,
float
c7
,
FragS
&
s1
)
{
uint2
res
[
2
];
res
[
0
]
=
to_half4
(
c0
,
c1
,
c2
,
c3
);
res
[
1
]
=
to_half4
(
c4
,
c5
,
c6
,
c7
);
half2
*
tmp
=
(
half2
*
)
&
res
;
// for per-column quantization we finally apply the scale here
if
constexpr
(
group_blocks
==
-
1
&&
num_bits
==
4
)
{
tmp
[
0
]
=
__hmul2
(
tmp
[
0
],
s0
[
0
]);
tmp
[
1
]
=
__hmul2
(
tmp
[
1
],
s0
[
1
]);
tmp
[
2
]
=
__hmul2
(
tmp
[
2
],
s1
[
0
]);
tmp
[
3
]
=
__hmul2
(
tmp
[
3
],
s1
[
1
]);
}
((
int4
*
)
sh
)[
idx
]
=
*
((
int4
*
)
&
res
[
0
]);
};
// RLC: only warp 0 and 1 baseline example
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
int
wr
=
c_sh_wr
;
write
(
wr
,
frag_c
[
i
][
0
][
0
][
0
],
frag_c
[
i
][
1
][
0
][
0
],
frag_c
[
i
][
2
][
0
][
0
],
frag_c
[
i
][
3
][
0
][
0
],
frag_s
[
0
][
0
],
frag_c
[
i
][
0
][
0
][
2
],
frag_c
[
i
][
1
][
0
][
2
],
frag_c
[
i
][
2
][
0
][
2
],
frag_c
[
i
][
3
][
0
][
2
],
frag_s
[
0
][
2
]);
write
(
wr
+
c_sh_stride
,
frag_c
[
i
][
0
][
0
][
1
],
frag_c
[
i
][
1
][
0
][
1
],
frag_c
[
i
][
2
][
0
][
1
],
frag_c
[
i
][
3
][
0
][
1
],
frag_s
[
0
][
0
],
frag_c
[
i
][
0
][
0
][
3
],
frag_c
[
i
][
1
][
0
][
3
],
frag_c
[
i
][
2
][
0
][
3
],
frag_c
[
i
][
3
][
0
][
3
],
frag_s
[
0
][
2
]);
write
(
wr
+
4
*
c_sh_stride_2
,
frag_c
[
i
][
0
][
1
][
0
],
frag_c
[
i
][
1
][
1
][
0
],
frag_c
[
i
][
2
][
1
][
0
],
frag_c
[
i
][
3
][
1
][
0
],
frag_s
[
0
][
0
],
frag_c
[
i
][
0
][
1
][
2
],
frag_c
[
i
][
1
][
1
][
2
],
frag_c
[
i
][
2
][
1
][
2
],
frag_c
[
i
][
3
][
1
][
2
],
frag_s
[
0
][
2
]);
write
(
wr
+
4
*
c_sh_stride_2
+
c_sh_stride
,
frag_c
[
i
][
0
][
1
][
1
],
frag_c
[
i
][
1
][
1
][
1
],
frag_c
[
i
][
2
][
1
][
1
],
frag_c
[
i
][
3
][
1
][
1
],
frag_s
[
0
][
0
],
frag_c
[
i
][
0
][
1
][
3
],
frag_c
[
i
][
1
][
1
][
3
],
frag_c
[
i
][
2
][
1
][
3
],
frag_c
[
i
][
3
][
1
][
3
],
frag_s
[
0
][
2
]);
c_sh_wr
+=
8
*
c_sh_stride_2
;
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
ceildiv
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
zero_accums
();
wait_for_stage
();
fetch_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
};
start_pipes
();
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines have
// even length meaning that the next iteration will always start at index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
matmul
(
pipe
);
wait_for_stage
();
fetch_to_registers
(
pipe
+
1
,
(
pipe
+
1
)
%
stages
);
pipe
++
;
slice_iters
--
;
if
(
slice_iters
==
0
)
break
;
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
group_blocks
==
-
1
)
{
if
constexpr
(
num_bits
==
8
)
{
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
cp_async_fence
();
}
else
{
if
(
last
)
{
if
(
s_sh_wr_pred
)
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
s
[
s_gl_rd
]);
cp_async_fence
();
}
}
}
thread_block_reduce
();
if
constexpr
(
group_blocks
==
-
1
)
{
if
constexpr
(
num_bits
==
8
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
*
(
float4
*
)(
frag_s
)
=
*
(
float4
*
)(
&
sh_s
[
s_sh_rd
]);
}
}
else
{
if
(
last
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
*
(
float4
*
)(
frag_s
)
=
*
(
float4
*
)(
&
sh_s
[
s_sh_rd
]);
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
group_blocks
==
-
1
&&
num_bits
==
8
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
scale_floats
(
&
frag_c
[
i
][
0
][
0
][
0
],
&
frag_c
[
i
][
1
][
0
][
0
],
&
frag_c
[
i
][
2
][
0
][
0
],
&
frag_c
[
i
][
3
][
0
][
0
],
frag_s
[
0
][
0
],
&
frag_c
[
i
][
0
][
0
][
2
],
&
frag_c
[
i
][
1
][
0
][
2
],
&
frag_c
[
i
][
2
][
0
][
2
],
&
frag_c
[
i
][
3
][
0
][
2
],
frag_s
[
0
][
2
]);
scale_floats
(
&
frag_c
[
i
][
0
][
0
][
1
],
&
frag_c
[
i
][
1
][
0
][
1
],
&
frag_c
[
i
][
2
][
0
][
1
],
&
frag_c
[
i
][
3
][
0
][
1
],
frag_s
[
0
][
0
],
&
frag_c
[
i
][
0
][
0
][
3
],
&
frag_c
[
i
][
1
][
0
][
3
],
&
frag_c
[
i
][
2
][
0
][
3
],
&
frag_c
[
i
][
3
][
0
][
3
],
frag_s
[
0
][
2
]);
scale_floats
(
&
frag_c
[
i
][
0
][
1
][
0
],
&
frag_c
[
i
][
1
][
1
][
0
],
&
frag_c
[
i
][
2
][
1
][
0
],
&
frag_c
[
i
][
3
][
1
][
0
],
frag_s
[
0
][
0
],
&
frag_c
[
i
][
0
][
1
][
2
],
&
frag_c
[
i
][
1
][
1
][
2
],
&
frag_c
[
i
][
2
][
1
][
2
],
&
frag_c
[
i
][
3
][
1
][
2
],
frag_s
[
0
][
2
]);
scale_floats
(
&
frag_c
[
i
][
0
][
1
][
1
],
&
frag_c
[
i
][
1
][
1
][
1
],
&
frag_c
[
i
][
2
][
1
][
1
],
&
frag_c
[
i
][
3
][
1
][
1
],
frag_s
[
0
][
0
],
&
frag_c
[
i
][
0
][
1
][
3
],
&
frag_c
[
i
][
1
][
1
][
3
],
&
frag_c
[
i
][
2
][
1
][
3
],
&
frag_c
[
i
][
3
][
1
][
3
],
frag_s
[
0
][
2
]);
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
#pragma unroll
for
(
int
i
=
0
;
i
<
m_sh_iters
;
i
++
)
meta_ptr
[
i
]
+=
(
m_sh_stride
)
-
m_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
m_sh_iters
;
i
++
)
meta_ptr
[
i
]
-=
m_gl_stride
;
}
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
start_pipes
();
}
}
}
}
#endif
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS) { \
cudaFuncSetAttribute( \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
C_ptr, s_ptr, prob_n, \
prob_m, prob_k, locks); \
}
void
marlin_cuda_2_4
(
const
void
*
A
,
const
void
*
B
,
const
void
*
meta
,
void
*
C
,
void
*
s
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
groupsize
=
-
1
,
int
dev
=
0
,
cudaStream_t
stream
=
0
,
int
thread_k
=
-
1
,
int
thread_m
=
-
1
,
int
sms
=
-
1
,
int
max_par
=
16
)
{
int
tot_n
=
prob_n
;
int
tot_n_blocks
=
ceildiv
(
tot_n
,
16
);
int
pad
=
16
*
tot_n_blocks
-
tot_n
;
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
TORCH_CHECK
(
sms
>
0
);
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
if
(
thread_k
==
-
1
||
thread_m
==
-
1
)
{
if
(
prob_n
<=
16
)
{
// For small batchizes, better partitioningif is slightly more important
// than better compute utilization
thread_k
=
128
;
thread_m
=
128
;
}
else
if
(
prob_n
<=
256
)
{
thread_k
=
64
;
thread_m
=
256
;
}
else
{
thread_k
=
32
;
thread_m
=
512
;
}
}
int
thread_k_blocks
=
thread_k
/
32
;
// 2:4 version with m16n8k32 instruction
int
thread_m_blocks
=
thread_m
/
16
;
int
group_blocks
=
(
groupsize
==
-
1
)
?
-
1
:
groupsize
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_m
%
thread_m
==
0
,
"prob_m = "
,
prob_m
,
" is not divisible by thread_m = "
,
thread_m
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
if
(
group_blocks
!=
-
1
)
{
TORCH_CHECK
((
prob_k
/
2
)
%
group_blocks
==
0
,
"prob_k/2 = "
,
prob_k
/
2
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
const
int4
*
meta_ptr
=
(
const
int4
*
)
meta
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
constexpr
int
max_m_blocks
=
4
;
int
*
locks
=
(
int
*
)
workspace
;
for
(
int
i
=
0
;
i
<
tot_n_blocks
;
i
+=
max_m_blocks
)
{
int
thread_n_blocks
=
tot_n_blocks
-
i
;
prob_n
=
tot_n
-
16
*
i
;
int
par
=
1
;
if
(
thread_n_blocks
>
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
thread_n_blocks
-
pad
)
/
(
max_m_blocks
*
16
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_n
=
(
max_m_blocks
*
16
)
*
par
;
i
+=
max_m_blocks
*
(
par
-
1
);
thread_n_blocks
=
max_m_blocks
;
}
// For compilation speed, we only define the kernel configurations that have
// seemed useful (in terms of performance) in our testing, however many more
// are, in principle, possible.
// the false is start of the CALL_IF macros
if
(
false
)
{
}
// BMxBNxBK, group
// 4-bit
CALL_IF_2_4
(
4
,
8
,
1
,
4
,
-
1
)
// e.g., 16x128x128
CALL_IF_2_4
(
4
,
8
,
1
,
4
,
4
)
// e.g., 16x128x128, 64
CALL_IF_2_4
(
4
,
16
,
1
,
2
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
4
,
16
,
1
,
2
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
4
,
16
,
2
,
2
,
-
1
)
// e.g.. 32x256x64
CALL_IF_2_4
(
4
,
16
,
2
,
2
,
4
)
CALL_IF_2_4
(
4
,
16
,
3
,
2
,
-
1
)
CALL_IF_2_4
(
4
,
16
,
3
,
2
,
4
)
CALL_IF_2_4
(
4
,
16
,
4
,
2
,
-
1
)
CALL_IF_2_4
(
4
,
16
,
4
,
2
,
4
)
CALL_IF_2_4
(
4
,
32
,
1
,
1
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
4
,
32
,
1
,
1
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
4
,
32
,
2
,
1
,
-
1
)
// e.g.. 32x256x64
CALL_IF_2_4
(
4
,
32
,
2
,
1
,
4
)
CALL_IF_2_4
(
4
,
32
,
3
,
1
,
-
1
)
CALL_IF_2_4
(
4
,
32
,
3
,
1
,
4
)
CALL_IF_2_4
(
4
,
32
,
4
,
1
,
-
1
)
CALL_IF_2_4
(
4
,
32
,
4
,
1
,
4
)
// 8-bit
CALL_IF_2_4
(
8
,
8
,
1
,
4
,
-
1
)
// e.g., 16x128x128
CALL_IF_2_4
(
8
,
8
,
1
,
4
,
4
)
// e.g., 16x128x128, 64
CALL_IF_2_4
(
8
,
16
,
1
,
2
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
8
,
16
,
1
,
2
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
8
,
16
,
2
,
2
,
-
1
)
// e.g.. 32x256x64
CALL_IF_2_4
(
8
,
16
,
2
,
2
,
4
)
CALL_IF_2_4
(
8
,
16
,
3
,
2
,
-
1
)
CALL_IF_2_4
(
8
,
16
,
3
,
2
,
4
)
CALL_IF_2_4
(
8
,
16
,
4
,
2
,
-
1
)
CALL_IF_2_4
(
8
,
16
,
4
,
2
,
4
)
CALL_IF_2_4
(
8
,
32
,
1
,
1
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
8
,
32
,
1
,
1
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
8
,
32
,
2
,
1
,
-
1
)
// e.g.. 32x256x64
CALL_IF_2_4
(
8
,
32
,
2
,
1
,
4
)
CALL_IF_2_4
(
8
,
32
,
3
,
1
,
-
1
)
CALL_IF_2_4
(
8
,
32
,
3
,
1
,
4
)
CALL_IF_2_4
(
8
,
32
,
4
,
1
,
-
1
)
CALL_IF_2_4
(
8
,
32
,
4
,
1
,
4
)
else
{
throw
std
::
runtime_error
(
"Unsupported shapes: MKN = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_k
)
+
", "
+
str
(
prob_n
)
+
"]"
+
", groupsize = "
+
str
(
groupsize
)
+
", thread_m_blocks = "
+
str
(
thread_m_blocks
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
A_ptr
+=
16
*
thread_n_blocks
*
(
prob_k
/
8
)
*
par
;
C_ptr
+=
16
*
thread_n_blocks
*
(
prob_m
/
8
)
*
par
;
}
}
}
// namespace marlin_24
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
// Verify M
TORCH_CHECK
(
size_m
==
a
.
size
(
0
),
"Shape mismatch: a.size(0) = "
+
str
(
a
.
size
(
0
))
+
", size_m = "
+
str
(
size_m
));
// Verify K
TORCH_CHECK
(
size_k
==
a
.
size
(
1
),
"Shape mismatch: a.size(1) = "
+
str
(
a
.
size
(
1
))
+
", size_k = "
+
str
(
size_k
));
TORCH_CHECK
(
size_k
%
marlin_24
::
tile_size
==
0
,
"size_k = "
+
str
(
size_k
)
+
" is not divisible by tile_size = "
+
str
(
marlin_24
::
tile_size
));
TORCH_CHECK
((
size_k
/
marlin_24
::
tile_size
/
2
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
+
str
(
b_q_weight
.
size
(
0
))
+
", size_k = "
+
str
(
size_k
)
+
", tile_size = "
+
str
(
marlin_24
::
tile_size
));
// Verify N
TORCH_CHECK
(
b_scales
.
size
(
1
)
==
size_n
,
"b_scales.size(1) = "
+
str
(
b_scales
.
size
(
1
))
+
", size_n = "
+
str
(
size_n
));
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin_24
::
tile_size
==
0
,
"b_q_weight.size(1) = "
+
str
(
b_q_weight
.
size
(
1
))
+
" is not divisible by tile_size = "
+
str
(
marlin_24
::
tile_size
));
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin_24
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
+
str
(
size_n
)
+
", actual_size_n = "
+
str
(
actual_size_n
));
// Verify meta
TORCH_CHECK
(
b_meta
.
size
(
0
)
==
size_k
/
8
/
2
/
2
,
"b_meta.size(0) = "
,
b_meta
.
size
(
0
),
" is not size_k / 8 / 2 / 2 = "
,
size_k
/
8
/
2
/
2
);
TORCH_CHECK
(
b_meta
.
size
(
1
)
==
size_n
*
2
,
"b_meta.size(1) = "
,
b_meta
.
size
(
1
),
" is not size_n * 2 = "
,
size_n
*
2
);
// Verify A device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
// Verify B device and strides
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
// Verify b_meta device and strides
TORCH_CHECK
(
b_meta
.
device
().
is_cuda
(),
"b_meta is not on GPU"
);
TORCH_CHECK
(
b_meta
.
is_contiguous
(),
"b_meta is not contiguous"
);
// Verify scales device and strides
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
// Alloc C matrix
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
int
thread_k
=
-
1
;
int
thread_m
=
-
1
;
int
sms
=
-
1
;
int
max_par
=
marlin_24
::
max_par
;
int
groupsize
=
-
1
;
if
(
b_scales
.
size
(
0
)
>
1
)
{
TORCH_CHECK
(
size_k
%
b_scales
.
size
(
0
)
==
0
,
"size_k = "
+
str
(
size_k
)
+
", is not divisible by b_scales.size(0) = "
+
str
(
b_scales
.
size
(
0
)));
groupsize
=
size_k
/
b_scales
.
size
(
0
);
groupsize
/=
2
;
// Because of 24
}
// Verify groupsize
TORCH_CHECK
(
groupsize
==
-
1
||
groupsize
==
64
,
"Unexpected groupsize = "
+
str
(
groupsize
));
// Verify workspace size
TORCH_CHECK
(
size_n
%
marlin_24
::
min_thread_n
==
0
,
"size_n = "
+
str
(
size_n
)
+
", is not divisible by min_thread_n = "
+
str
(
marlin_24
::
min_thread_n
));
int
min_workspace_size
=
(
size_n
/
marlin_24
::
min_thread_n
)
*
marlin_24
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
+
str
(
workspace
.
numel
())
+
" is below min_workspace_size = "
+
str
(
min_workspace_size
));
int
dev
=
a
.
get_device
();
marlin_24
::
marlin_cuda_2_4
(
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
b_meta
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
size_n
,
size_m
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
groupsize
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_m
,
sms
,
max_par
);
return
c
;
}
Prev
1
…
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment