ox696c / ktransformers · Commits · 877aec85

Unverified commit 877aec85, authored Apr 09, 2025 by Yuhao Tsui; committed by GitHub, Apr 09, 2025.

Commit message: Merge branch 'kvcache-ai:main' into main

Parents: 84164f58, 9037bf30

Changes: 251. Showing 20 changed files with 4098 additions and 15 deletions (+4098 / -15).
csrc/balance_serve/sched/utils/statistics.hpp              +77   -0
csrc/balance_serve/sched/utils/timer.hpp                   +132  -0
csrc/custom_marlin/__init__.py                             +0    -0
csrc/custom_marlin/binding.cpp                             +44   -0
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu              +2034 -0
csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh             +76   -0
csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh      +77   -0
csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu       +350  -0
csrc/custom_marlin/gptq_marlin/ops.h                       +24   -0
csrc/custom_marlin/setup.py                                +25   -0
csrc/custom_marlin/test_cuda_graph.py                      +335  -0
csrc/custom_marlin/utils/__init__.py                       +0    -0
csrc/custom_marlin/utils/format24.py                       +308  -0
csrc/custom_marlin/utils/marlin_24_perms.py                +65   -0
csrc/custom_marlin/utils/marlin_perms.py                   +65   -0
csrc/custom_marlin/utils/marlin_utils.py                   +234  -0
csrc/custom_marlin/utils/quant_utils.py                    +195  -0
csrc/ktransformers_ext/CMakeLists.txt                      +57   -15
csrc/ktransformers_ext/bench/bench_attention.py            +0    -0
csrc/ktransformers_ext/bench/bench_attention_torch.py      +0    -0
csrc/balance_serve/sched/utils/statistics.hpp (new file, mode 100644)
#ifndef STATISTICS_HPP
#define STATISTICS_HPP

#include <chrono>
#include <iostream>
#include <string>
#include <unordered_map>

class Statistics {
public:
  // Increment the counter for a given key by a specified value (default is 1)
  void increment_counter(const std::string &key, int64_t value = 1) {
    counters_[key] += value;
  }

  int64_t &get_counter(const std::string &key) { return counters_[key]; }

  // Start the timer for a given key
  void start_timer(const std::string &key) {
    active_timers_[key] = std::chrono::high_resolution_clock::now();
  }

  // Stop the timer for a given key and update the total time and count
  void stop_timer(const std::string &key) {
    auto start_it = active_timers_.find(key);
    if (start_it != active_timers_.end()) {
      auto duration =
          std::chrono::high_resolution_clock::now() - start_it->second;
      timings_[key].total_time += duration;
      timings_[key].count += 1;
      active_timers_.erase(start_it);
    } else {
      // Handle error: stop_timer called without a matching start_timer
      std::cerr << "Warning: stop_timer called for key '" << key
                << "' without a matching start_timer.\n";
    }
  }

  // Print out the collected statistical information
  void report() const {
    std::cout << "Counters:\n";
    for (const auto &kv : counters_) {
      std::cout << "  " << kv.first << ": " << kv.second << "\n";
    }
    std::cout << "\nTimers:\n";
    for (const auto &kv : timings_) {
      std::cout << "  " << kv.first << ": count = " << kv.second.count
                << ", total_time = " << kv.second.total_time.count() << "s"
                << ", average_time = "
                << (kv.second.count > 0
                        ? kv.second.total_time.count() / kv.second.count
                        : 0)
                << "s\n";
    }
  }

private:
  // Mapping from key to counter
  std::unordered_map<std::string, int64_t> counters_;

  // Struct to hold timing information for a key
  struct TimingInfo {
    int64_t count = 0;
    std::chrono::duration<double> total_time =
        std::chrono::duration<double>::zero();
  };

  // Mapping from key to timing information
  std::unordered_map<std::string, TimingInfo> timings_;

  // Mapping from key to the start time of active timers
  std::unordered_map<std::string,
                     std::chrono::high_resolution_clock::time_point>
      active_timers_;
};

#endif // STATISTICS_HPP
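A minimal usage sketch for the Statistics helper above; the keys ("requests", "decode") are hypothetical and only illustrate the counter/timer API, they are not taken from the scheduler code:

#include "statistics.hpp"

int main() {
  Statistics stats;
  stats.increment_counter("requests");     // default increment of 1
  stats.increment_counter("tokens", 128);  // explicit increment

  stats.start_timer("decode");
  // ... work to be measured ...
  stats.stop_timer("decode");              // accumulates total_time and count for "decode"

  stats.report();                          // prints all counters and per-key timing averages
  return 0;
}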
csrc/balance_serve/sched/utils/timer.hpp (new file, mode 100644)
#pragma once

#include "readable_number.hpp"
#include <cassert>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <string>

inline std::string doubleToStringR2(double value) {
  std::stringstream stream;
  stream << std::fixed << std::setprecision(2) << value;
  return stream.str();
}

class Timer {
public:
  std::string name;
  bool tmp_timer = false;

  Timer() {}
  Timer(std::string name) : name(name), tmp_timer(true) { start(); }

  ~Timer() {
    if (tmp_timer) {
      std::cout << name << " " << elapsedMs() << " ms" << std::endl;
    }
  }

  void start() {
    m_startTime = std::chrono::high_resolution_clock::now();
    assert(m_isRunning == false);
    m_isRunning = true;
  }

  void stop() {
    m_endTime = std::chrono::high_resolution_clock::now();
    assert(m_isRunning == true);
    m_isRunning = false;
    m_runningNs += elapsedNs();
  }

  double elapsedNs() {
    std::chrono::time_point<std::chrono::high_resolution_clock> endTime;
    if (m_isRunning) {
      endTime = std::chrono::high_resolution_clock::now();
    } else {
      endTime = m_endTime;
    }
    return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime -
                                                                m_startTime)
        .count();
  }

  void printElapsedMilliseconds() {
    std::cout << elapsedNs() / 1e6 << " ms" << std::endl;
  }

  static std::string ns_to_string(double duration) {
    auto nano_sec = duration;
    if (nano_sec >= 1000) {
      auto mirco_sec = nano_sec / 1000.0;
      if (mirco_sec >= 1000) {
        auto milli_sec = mirco_sec / 1000.0;
        if (milli_sec >= 1000) {
          auto seconds = milli_sec / 1000.0;
          if (seconds >= 60.0) {
            auto minutes = seconds / 60.0;
            if (minutes >= 60.0) {
              auto hours = minutes / 60.0;
              return doubleToStringR2(hours) + " h";
            } else {
              return doubleToStringR2(minutes) + " min";
            }
          } else {
            return doubleToStringR2(seconds) + " sec";
          }
        } else {
          return doubleToStringR2(milli_sec) + " ms";
        }
      } else {
        return doubleToStringR2(mirco_sec) + " us";
      }
    } else {
      return doubleToStringR2(nano_sec) + " ns";
    }
  }

  double runningTimeNs() { return m_runningNs; }

  std::string runningTime() {
    auto duration = m_runningNs;
    return ns_to_string(duration);
  }

  std::string elapsedTime() { return ns_to_string(elapsedNs()); }

  double elapsedMs() { return elapsedNs() / 1e6; }

  std::string report_throughput(size_t op_cnt) {
    double ops = op_cnt / elapsedMs() * 1000;
    return readable_number(ops) + "op/s";
  }

  void merge(Timer &other) {
    assert(m_isRunning == false);
    assert(other.m_isRunning == false);
    m_runningNs += other.runningTimeNs();
  }

private:
  std::chrono::time_point<std::chrono::high_resolution_clock> m_startTime;
  std::chrono::time_point<std::chrono::high_resolution_clock> m_endTime;
  bool m_isRunning = false;
  double m_runningNs = 0.0;
};

class Counter {
public:
  Counter() {}

  std::map<std::string, size_t> counters;

  void inc(const char *name, size_t num) { counters[name] += num; };

  void print() {
    for (auto &p : counters) {
      std::cout << p.first << " : " << p.second << std::endl;
    }
  };
};
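A short sketch of how the scoped Timer and the Counter above might be used; the function and keys are hypothetical, and timer.hpp itself expects the readable_number.hpp helper referenced in its includes:

#include "timer.hpp"
#include <iostream>

void handle_batch() {
  Counter counter;
  counter.inc("batches", 1);

  {
    // The named constructor starts timing; the destructor prints "<name> <elapsed> ms".
    Timer scoped("handle_batch");
    // ... work to be measured ...
  }

  // start()/stop() pairs accumulate into the running total reported by runningTime().
  Timer accumulated;
  accumulated.start();
  // ... first timed section ...
  accumulated.stop();
  accumulated.start();
  // ... second timed section ...
  accumulated.stop();
  std::cout << "accumulated: " << accumulated.runningTime() << std::endl;

  counter.print();
}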
csrc/custom_marlin/__init__.py (new file, mode 100644)
csrc/custom_marlin/binding.cpp (new file, mode 100644)
/**
 * @Description  :
 * @Author       : Azure-Tang
 * @Date         : 2024-07-25 13:38:30
 * @Version      : 1.0.0
 * @LastEditors  : kkk1nak0
 * @LastEditTime : 2024-08-12 03:05:04
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/

#include "gptq_marlin/ops.h"
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/torch.h>
// namespace py = pybind11;

PYBIND11_MODULE(vLLMMarlin, m) {
  /*m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0
  data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
  m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k
  data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
  m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k
  data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
  m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k
  data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
  m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k
  data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
  m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k
  data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));
  m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize
  iq4_xs data.", py::arg("data"), py::arg("blk_size"), py::arg("device"));*/

  m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
        "Function to perform GEMM using Marlin quantization.", py::arg("a"),
        py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
        py::arg("perm"), py::arg("workspace"), py::arg("num_bits"),
        py::arg("size_m_tensor"), py::arg("size_m"), py::arg("size_n"),
        py::arg("size_k"), py::arg("sms"), py::arg("is_k_full"));

  m.def("gptq_marlin_repack", &gptq_marlin_repack,
        "gptq_marlin repack from GPTQ");
}
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu (new file, mode 100644)
/*
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Adapted from https://github.com/IST-DASLab/marlin
 */

/*
 * Adapted from
 * https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin
 */

#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include <c10/cuda/CUDAGuard.h>

#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
  static_assert(std::is_same<scalar_t, half>::value ||          \
                    std::is_same<scalar_t, nv_bfloat16>::value, \
                "only float16 and bfloat16 is supported");

template <typename T> inline std::string str(T x) { return std::to_string(x); }

namespace gptq_marlin {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
                                    int const *__restrict__ perm_int_ptr,
                                    int4 *__restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {}

template <typename scalar_t,         // compute dtype, half or nv_float16
          const int num_bits,        // number of bits used for weights
          const int threads,         // number of threads in a threadblock
          const int thread_m_blocks, // number of 16x16 blocks in the m
                                     // dimension (batchsize) of the
                                     // threadblock
          const int thread_n_blocks, // same for n dimension (output)
          const int thread_k_blocks, // same for k dimension (reduction)
          const int stages, // number of stages for the async global->shared
                            // fetch pipeline
          const bool has_act_order,   // whether act_order is enabled
          const int group_blocks = -1 // number of consecutive 16x16 blocks
                                      // with a separate quantization scale
          >
__global__ void
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
       const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
       int4 *__restrict__ C,       // fp16 output buffer of shape mxn
       const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape
                                            // (k/groupsize)xn
       const int *__restrict__ g_idx,       // int32 group indices of shape k
       int num_groups, // number of scale groups per output channel
       int prob_m,     // batch dimension m
       int prob_n,     // output dimension n
       int prob_k,     // reduction dimension k
       int *locks      // extra global storage for barrier synchronization
       ) {}

} // namespace gptq_marlin

torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                               torch::Tensor &b_scales, torch::Tensor &g_idx,
                               torch::Tensor &perm, torch::Tensor &workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full) {
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <typename scalar_t>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA &a_frag,
                           const typename ScalarType<scalar_t>::FragB &frag_b,
                           typename ScalarType<scalar_t>::FragC &frag_c) {
  const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
  const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
  float *c = reinterpret_cast<float *>(&frag_c);
  if constexpr (std::is_same<scalar_t, half>::value) {
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
                   "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
                   "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  } else {
    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
  }
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA &frag_a,
                             const void *smem_ptr) {
  uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}

template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  typename ScalarType<half>::FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                      *reinterpret_cast<const half2 *>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
                      *reinterpret_cast<const half2 *>(&MUL),
                      *reinterpret_cast<const half2 *>(&ADD));
  return frag_b;
}

template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit<nv_bfloat16>(int q) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);

  typename ScalarType<nv_bfloat16>::FragB frag_b;
  static constexpr uint32_t MUL = 0x3F803F80;
  static constexpr uint32_t ADD = 0xC308C308;
  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&lo),
                      *reinterpret_cast<const nv_bfloat162 *>(&MUL),
                      *reinterpret_cast<const nv_bfloat162 *>(&ADD));
  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162 *>(&hi),
                      *reinterpret_cast<const nv_bfloat162 *>(&MUL),
                      *reinterpret_cast<const nv_bfloat162 *>(&ADD));
  return frag_b;
}

// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}

template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  typename ScalarType<half>::FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}

template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit<nv_bfloat16>(int q) {
  typename ScalarType<nv_bfloat16>::FragB frag_b;

  float fp32_intermediates[4];
  uint32_t *fp32_intermediates_casted =
      reinterpret_cast<uint32_t *>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  fp32_intermediates[0] -= 8388736.f;
  fp32_intermediates[1] -= 8388736.f;
  fp32_intermediates[2] -= 8388736.f;
  fp32_intermediates[3] -= 8388736.f;

  uint32_t *bf16_result_ptr = reinterpret_cast<uint32_t *>(&frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);

  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB &frag_b,
                             typename ScalarType<scalar_t>::FragS &frag_s,
                             int i) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  scalar_t2 s =
      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t *>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB &frag_b,
                              typename ScalarType<scalar_t>::FragS &frag_s_1,
                              typename ScalarType<scalar_t>::FragS &frag_s_2,
                              typename ScalarType<scalar_t>::FragS &frag_s_3,
                              typename ScalarType<scalar_t>::FragS &frag_s_4,
                              int i) {
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
  scalar_t2 s_val_1_2;
  s_val_1_2.x = reinterpret_cast<scalar_t *>(&frag_s_1)[i];
  s_val_1_2.y = reinterpret_cast<scalar_t *>(&frag_s_2)[i];

  scalar_t2 s_val_3_4;
  s_val_3_4.x = reinterpret_cast<scalar_t *>(&frag_s_3)[i];
  s_val_3_4.y = reinterpret_cast<scalar_t *>(&frag_s_4)[i];

  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}

// Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t>
__device__ inline void scale_float(float *c,
                                   typename ScalarType<scalar_t>::FragS &s) {
  scalar_t *s_ptr = reinterpret_cast<scalar_t *>(&s);
  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be
      // visible globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int *lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
                                    int const *__restrict__ perm_int_ptr,
                                    int4 *__restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = finish_row - start_row;

  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;

    int offset = row * row_stride;

    half const *a_row_half =
        reinterpret_cast<half const *>(a_int4_ptr + offset);
    half *out_half = reinterpret_cast<half *>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}
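The dequantization helpers above never issue an explicit int-to-float conversion: they splice the quantized bits into the mantissa of a carefully chosen floating-point constant and then subtract (or FMA away) the resulting bias. A minimal host-side sketch of that idea, mirroring the fp32_base / 8388736.f constants used in dequant_8bit<nv_bfloat16>; it assumes IEEE-754 binary32 and memcpy-based type punning, and is illustrative only, not part of this commit:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  for (int b = 0; b < 256; ++b) {
    // Splice the byte into the mantissa of 2^23: the bit pattern 0x4B000000 | b
    // is exactly the float 8388608.0f + b.
    uint32_t bits = 0x4B000000u | static_cast<uint32_t>(b);
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    // Subtracting 8388736.0f (= 2^23 + 128) removes the base and the +128
    // zero point in one step, yielding the signed value b - 128.
    float dequant = f - 8388736.0f;
    if (dequant != static_cast<float>(b - 128)) {
      std::printf("mismatch at byte %d\n", b);
      return 1;
    }
  }
  std::printf("all 256 byte patterns dequantize to b - 128\n");
  return 0;
}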
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
>
__device__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m, should be divisible by (16 * thread_m_blocks) if bigger than that
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
constexpr
int
pack_factor
=
32
/
num_bits
;
// int prob_m = *prob_m_ptr;
// const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
// constexpr int thread_m_blocks = template_thread_m_blocks;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
if
constexpr
(
!
has_act_order
&&
group_blocks
!=
-
1
)
{
if
(
group_blocks
>=
thread_k_blocks
)
{
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters
=
(
group_blocks
/
thread_k_blocks
)
*
div_ceil
(
iters
,
(
group_blocks
/
thread_k_blocks
));
}
}
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
;
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
}
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
}
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
{
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_s
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
{
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
}
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
if
constexpr
(
has_act_order
)
{
// Fetch g_idx thread-block portion
int
full_pipe
=
a_off
;
int
cur_k
=
slice_k_start_shared_fetch
+
tb_k
*
full_pipe
;
if
(
cur_k
<
prob_k
&&
cur_k
<
slice_k_finish
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
const
*
cur_g_idx_stage_ptr
=
reinterpret_cast
<
int4
const
*>
(
&
g_idx
[
cur_k
]);
if
(
threadIdx
.
x
<
g_idx_stage
)
{
cp_async4_pred
(
&
sh_g_idx_stage
[
threadIdx
.
x
],
&
cur_g_idx_stage_ptr
[
threadIdx
.
x
]);
}
}
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
ldsm4
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
if
constexpr
(
!
has_act_order
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
group_id_1
=
sh_g_idx_int_ptr
[
0
];
int
group_id_2
=
sh_g_idx_int_ptr
[
tb_k
-
1
];
is_same_group
[
pipe
]
=
group_id_1
==
group_id_2
;
same_group_id
[
pipe
]
=
group_id_1
;
};
auto
fetch_scales_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
!
has_act_order
)
{
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
}
return
;
}
// Act-order case
// Determine K of the "current" thread-block
int
cur_k
=
slice_k_start
+
tb_k
*
full_pipe
;
if
(
cur_k
>=
prob_k
||
cur_k
>=
slice_k_finish
)
{
return
;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k
=
0
;
// Progress to current iteration
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
int
warp_row
=
warp_id
/
n_warps
;
int
warp_col
=
warp_id
%
n_warps
;
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
/*slice_n_offset +*/
(
act_s_col_warp_stride
*
warp_col
)
+
(
th_id
/
4
)
*
act_s_col_stride
;
if
(
is_same_group
[
pipe
])
{
if
(
k
%
2
==
0
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
sh_s
[(
same_group_id
[
pipe
]
-
sh_first_group_id
)
*
s_sh_stride
+
s_col_shift
];
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[(
k
-
1
)
%
2
][
0
][
0
])));
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])));
}
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
constexpr
int
k_frag_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
// Tensor core offsets per thread
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
actual_k
=
cur_k
+
k_frag_offsets
[
i
];
int
group_id
=
sh_g_idx_int_ptr
[
actual_k
];
int
rel_group_id
=
group_id
-
sh_first_group_id
;
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
sh_s
[
rel_group_id
*
s_sh_stride
+
s_col_shift
];
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
if
constexpr
(
num_bits
==
4
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
}
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
Dtype
::
num2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
Dtype
::
float2num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2num
(
c0
),
Dtype
::
float2num
(
c1
));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
4
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
((
scalar_t2
*
)
sh
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
C
[
c_gl_wr
]
=
sh
[
c_sh_rd
];
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
if
(
has_act_order
&&
i
==
0
)
{
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
], g_idx[last_g_idx]);
      }
      fetch_to_shared(i, i, i < slice_iters);
    }

    zero_accums();
    wait_for_stage();
    init_same_group(0);
    fetch_to_registers(0, 0);
    fetch_scales_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
    slice_k_start_shared_fetch += tb_k * (stages - 1);
  };

  if (slice_iters) {
    start_pipes();
  }

  // Main loop.
  while (slice_iters) {
    // We unroll over both the global fetch and the register load pipeline to
    // ensure all shared memory accesses are static. Note that both pipelines
    // have even length meaning that the next iteration will always start at
    // index 0.
#pragma unroll
    for (int pipe = 0; pipe < stages;) {
#pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        fetch_scales_to_registers(k + 1, pipe);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe,
                          slice_iters >= stages);
          pipe++;
          wait_for_stage();
          init_same_group(pipe % stages);
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) {
        break;
      }
    }

    a_gl_rd += a_gl_rd_delta_o * stages;
    slice_k_start += tb_k * stages;
    slice_k_start_shared_fetch += tb_k * stages;

    if constexpr (has_act_order) {
      int first_group_id = g_idx[slice_k_start];
      int last_g_idx = slice_k_start + stages * tb_k * 2;
      if (last_g_idx >= prob_k) {
        last_g_idx = prob_k - 1;
      }
      int last_group_id = g_idx[last_g_idx];
      if (last_group_id >= sh_first_group_id + sh_num_groups) {
        fetch_scales_to_shared(false, first_group_id, last_group_id);
        __syncthreads();
      }
    }

    // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to noticeably worsen performance after compilation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step
      // before write-out
      if constexpr (!has_act_order && group_blocks == -1) {
        if constexpr (num_bits == 8) {
          if (s_sh_wr_pred) {
            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
          }
          cp_async_fence();
        } else {
          if (last) {
            if (s_sh_wr_pred) {
              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
            }
            cp_async_fence();
          }
        }
      }

      thread_block_reduce();

      if constexpr (!has_act_order && group_blocks == -1) {
        if constexpr (num_bits == 8) {
          cp_async_wait<0>();
          __syncthreads();
          if (threadIdx.x / 32 < thread_n_blocks / 4) {
            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
          }
        } else {
          if (last) {
            cp_async_wait<0>();
            __syncthreads();
            if (threadIdx.x / 32 < thread_n_blocks / 4) {
              reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
              reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
            }
          }
        }
      }

      // For 8-bit channelwise, we apply the scale before the global reduction
      // that converts the fp32 results to fp16 (so that we avoid possible
      // overflow in fp16)
      if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
        if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
          for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
            for (int j = 0; j < 4; j++) {
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][0][0]),
                  frag_s[j / 2][2 * (j % 2) + 0]);
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][0][2]),
                  frag_s[j / 2][2 * (j % 2) + 0]);
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][1][0]),
                  frag_s[j / 2][2 * (j % 2) + 1]);
              scale_float<scalar_t>(
                  reinterpret_cast<float*>(&frag_c[i][j][1][2]),
                  frag_s[j / 2][2 * (j % 2) + 1]);
            }
          }
        }
      }

      if (slice_count > 1) {
        // only globally reduce if there is more than one block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        global_reduce(slice_idx == 0, last);
        barrier_release(&locks[slice_col], last);
      }
      if (last)
        // only the last block in a slice actually writes the result
        write_result();

      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();

      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
#pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }

        // Update slice k/n for scales loading
        if constexpr (has_act_order) {
          slice_k_start = tb_k * slice_row;
          slice_k_finish = slice_k_start + tb_k * slice_iters;
          slice_k_start_shared_fetch = slice_k_start;
          slice_n_offset = act_s_col_tb_stride * slice_col;
        } else {
          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
        }

        start_pipes();
      }
    }
  }
}
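The comment above about applying the 8-bit channelwise scale before the fp32-to-fp16 conversion can be illustrated numerically. Below is a minimal PyTorch sketch; the accumulator magnitude and the 0.01 scale are made-up illustrative values, not taken from this kernel.

import torch

# Accumulate a dot product in fp32, as the kernel does in frag_c.
acc = torch.full((1,), 4096.0 * 30.0)   # raw int8-range accumulation, ~1.2e5
scale = torch.full((1,), 0.01)          # per-output-channel dequant scale

# Converting the unscaled accumulator to fp16 overflows (fp16 max ~65504).
print(acc.to(torch.float16))            # inf
# Scaling first, while still in fp32, keeps the value inside fp16 range.
print((acc * scale).to(torch.float16))  # finite, ~1.2e3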
template <typename scalar_t,          // compute dtype, half or nv_float16
          const int num_bits,         // number of bits used for weights
          const int threads,          // number of threads in a threadblock
          const int template_thread_m_blocks,  // number of 16x16 blocks in the
                                               // m dimension (batchsize) of
                                               // the threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,    // whether act_order is enabled
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin_wrapper(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    const int* __restrict__ g_idx,        // int32 group indices of shape k
    int num_groups,       // number of scale groups per output channel
    const int* __restrict__ prob_m_ptr,   // batch dimension m
    int prob_n,           // output dimension n
    int prob_k,           // reduction dimension k
    int* locks            // extra global storage for barrier synchronization
) {
  int prob_m = *prob_m_ptr;
  const int thread_m_blocks =
      min(div_ceil(prob_m, 16), template_thread_m_blocks);
  if (prob_m > 16 * thread_m_blocks)
    prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
  /*if (blockIdx.x == 0 && threadIdx.x == 0)
    printf("marlin prob_m %d\n", prob_m);*/
  if (thread_m_blocks == 1) {
    Marlin<scalar_t, num_bits, threads, 1, thread_n_blocks, thread_k_blocks,
           stages, has_act_order, group_blocks>(
        A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks);
  } else if (thread_m_blocks == 2) {
    Marlin<scalar_t, num_bits, threads, 2, thread_n_blocks, thread_k_blocks,
           stages, has_act_order, group_blocks>(
        A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks);
  } else if (thread_m_blocks == 3) {
    Marlin<scalar_t, num_bits, threads, 3, thread_n_blocks, thread_k_blocks,
           stages, has_act_order, group_blocks>(
        A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks);
  } else if (thread_m_blocks == 4) {
    Marlin<scalar_t, num_bits, threads, 4, thread_n_blocks, thread_k_blocks,
           stages, has_act_order, group_blocks>(
        A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n, prob_k, locks);
  }
}
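Because prob_m is read from prob_m_ptr inside the kernel, the grid and template parameters can stay fixed while the effective batch size changes between launches, which is presumably what lets test_cuda_graph.py (later in this commit) replay a captured CUDA graph after only updating bsz_tensor. A small Python sketch of the runtime clamping done above (names mirror the kernel, values are illustrative):

def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b

def effective_m_blocks(prob_m: int, template_thread_m_blocks: int):
    # Clamp the compile-time tile height to what the runtime batch needs.
    thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks)
    # Round prob_m up to a multiple of the chosen tile height, as the wrapper does.
    if prob_m > 16 * thread_m_blocks:
        prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, 16 * thread_m_blocks)
    return thread_m_blocks, prob_m

print(effective_m_blocks(5, 4))    # (1, 5)   -> dispatches the m-blocks=1 instantiation
print(effective_m_blocks(70, 4))   # (4, 128) -> dispatches the m-blocks=4 instantiation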
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin_wrapper<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin_wrapper<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, \
prob_k, locks); \
}
typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

typedef struct {
  int max_m_blocks;
  thread_config_t tb_cfg;
} exec_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {128, 128, 256},
    {64, 128, 128},
    {128, 64, 128},
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {64, 256, 256},
    // {128, 128, 256},
    {64, 128, 128},
    {128, 64, 128},
};
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  bool cache_scales_chunk = has_act_order && !is_k_full;

  int tb_n = th_config.thread_n;
  int tb_k = th_config.thread_k;

  // Get max scale groups per thread-block
  int tb_groups;
  if (group_size == -1) {
    tb_groups = 1;
  } else if (group_size == 0) {
    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size
  } else {
    tb_groups = div_ceil(tb_k, group_size);
  }

  if (cache_scales_chunk) {
    int load_groups =
        tb_groups * pipe_stages * 2;     // Chunk size is 2x pipeline over dim K
    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
    return load_groups * tb_n * 2;
  } else {
    int tb_scales = tb_groups * tb_n * 2;
    return tb_scales * pipe_stages;
  }
}
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int pack_factor = 32 / num_bits;

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;

  int b_size = (tb_k * tb_n / pack_factor) * 4;

  // Get A size
  int m_blocks = div_ceil(prob_m, 16);
  int tb_max_m = 16;

  // zbx: too ugly
  // origin
  /*while (true) {
    if (m_blocks >= max_m_blocks) {
      tb_max_m *= max_m_blocks;
      break;
    }
    max_m_blocks--;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }*/
  // refactor
  tb_max_m *= std::min(m_blocks, max_m_blocks);

  int a_size = (tb_max_m * tb_k) * 2;

  float pipe_size = (a_size + b_size) * pipe_stages;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  // Determine cache for scales
  int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);

  // Check that pipeline fits into cache
  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                           num_bits, scales_cache_size, max_shared_mem)) {
    return false;
  }

  return true;
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  int max_m_blocks = 4;
  while (max_m_blocks > 0) {
    if (prob_m <= 16) {
      for (auto th_config : small_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    } else {
      for (auto th_config : large_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }

    max_m_blocks--;  // Process less M blocks per invocation to reduce cache
                     // usage
  }

  return exec_config_t{0, {-1, -1, -1}};
}
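The search above walks the candidate tables in priority order and keeps the first (thread_k, thread_n, num_threads) triple whose pipeline fits in shared memory. A rough Python sketch of that feasibility test, assuming the pipe_stages = 4 constant from gptq_marlin.cuh and no act_order scale chunking; the numbers at the end are illustrative only:

PIPE_STAGES = 4

def pipeline_fits(thread_k, thread_n, prob_m, num_bits, max_shared_mem,
                  scales_cache_size, max_m_blocks=4):
    pack_factor = 32 // num_bits
    b_size = (thread_k * thread_n // pack_factor) * 4   # bytes of packed B per stage
    m_blocks = (prob_m + 15) // 16
    tb_max_m = 16 * min(m_blocks, max_m_blocks)
    a_size = tb_max_m * thread_k * 2                     # bytes of fp16 A per stage
    pipe_size = (a_size + b_size) * PIPE_STAGES
    return pipe_size < 0.95 * (max_shared_mem - scales_cache_size)

# e.g. a 128x128 tile with 4-bit weights and ~100 KiB of opt-in shared memory
print(pipeline_fits(128, 128, prob_m=16, num_bits=4,
                    max_shared_mem=100 * 1024, scales_cache_size=4096))  # True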
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
                     void* g_idx, void* perm, void* a_tmp, int* prob_m_ptr,
                     int prob_m, int prob_n, int prob_k, void* workspace,
                     int num_bits, bool has_act_order, bool is_k_full,
                     int num_groups, int group_size, int dev,
                     cudaStream_t stream, int thread_k, int thread_n, int sms,
                     int max_par) {
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [",
              prob_m, ", ", prob_n, ", ", prob_k, "]");

  int tot_m = prob_m;
  int tot_m_blocks = div_ceil(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(
      exec_cfg.max_m_blocks > 0 &&
          is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
                          prob_n, prob_k, num_bits, group_size, has_act_order,
                          is_k_full, max_shared_mem),
      "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
      ", thread_k = ", exec_cfg.tb_cfg.thread_k,
      ", thread_n = ", exec_cfg.tb_cfg.thread_n,
      ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
      ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
      ", group_size = ", group_size, ", has_act_order = ", has_act_order,
      ", is_k_full = ", is_k_full, ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  const int4* s_ptr = (const int4*)s;
  const int* g_idx_ptr = (const int*)g_idx;
  const int* perm_ptr = (const int*)perm;
  int4* a_tmp_ptr = (int4*)a_tmp;

  int* locks = (int*)workspace;

  if (has_act_order) {
    // Permute A columns
    int block_rows = div_ceil(prob_m, blocks);
    permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
    A_ptr = a_tmp_ptr;
  }

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  // Main loop
  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > exec_cfg.max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without
      // any padding
      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * exec_cfg.max_m_blocks) * par;
      i += exec_cfg.max_m_blocks * (par - 1);
      thread_m_blocks = exec_cfg.max_m_blocks;
    }

    // Define kernel configurations
#define undefined_error                                                    \
  TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \
                         str(prob_n) + ", " + str(prob_k) + "]" +         \
                         ", has_act_order = " + str(has_act_order) +      \
                         ", num_groups = " + str(num_groups) +            \
                         ", group_size = " + str(group_size) +            \
                         ", thread_m_blocks = " + str(thread_m_blocks) +  \
                         ", thread_n_blocks = " + str(thread_n_blocks) +  \
                         ", thread_k_blocks = " + str(thread_k_blocks));

    /* std::cout << "MNK = [" + str(prob_m) + ", " +
           str(prob_n) + ", " + str(prob_k) + "]" +
           ", has_act_order = " + str(has_act_order) +
           ", num_groups = " + str(num_groups) +
           ", group_size = " + str(group_size) +
           ", thread_m_blocks = " + str(thread_m_blocks) +
           ", thread_n_blocks = " + str(thread_n_blocks) +
           ", thread_k_blocks = " + str(thread_k_blocks) << std::endl; */

    /*if (false) {
    }
    // CALL_IF(4, 32, 2, 256)
    // CALL_IF(4, 16, 4, 256)
    __CALL_IF(4, 1, 16, 4, false, 4, 256)
    __CALL_IF(4, 2, 16, 4, false, 4, 256)
    // CALL_IF(4, 8, 8, 256)
    __CALL_IF(4, 1, 8, 8, false, 4, 256)
    __CALL_IF(4, 2, 8, 8, false, 4, 256)
    // CALL_IF(4, 16, 4, 128)
    __CALL_IF(4, 1, 16, 4, false, 4, 128)
    __CALL_IF(4, 2, 16, 4, false, 4, 128)
    // CALL_IF(4, 8, 8, 128)
    __CALL_IF(4, 1, 8, 8, false, 4, 128)
    __CALL_IF(4, 2, 8, 8, false, 4, 128)
    else {undefined_error}*/

    if (num_bits == 4 && num_threads == 256) {
      if (false) {
      }
      CALL_IF(4, 32, 2, 256)
      CALL_IF(4, 16, 4, 256)
      CALL_IF(4, 8, 8, 256)
      else {
        undefined_error
      }
    } else if (num_bits == 4 && num_threads == 128) {
      if (false) {
      }
      CALL_IF(4, 8, 4, 128)
      CALL_IF(4, 16, 4, 128)
      CALL_IF(4, 4, 8, 128)
      else {
        undefined_error
      }
    }
    // else if (num_bits == 8 && num_threads == 256)
    // {
    //   if (false) {
    //   }
    //   CALL_IF(8, 32, 2, 256)
    //   CALL_IF(8, 16, 4, 256)
    //   CALL_IF(8, 8, 8, 256)
    //   else {
    //     undefined_error
    //   }
    // }
    // else if (num_bits == 8 && num_threads == 128)
    // {
    //   if (false) {
    //   }
    //   CALL_IF(8, 8, 4, 128)
    //   CALL_IF(8, 16, 4, 128)
    //   CALL_IF(8, 4, 8, 128)
    //   else {
    //     undefined_error
    //   }
    // }
    else {
      undefined_error
    }

    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
}

}  // namespace gptq_marlin
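The "Main loop" above splits the M dimension into chunks of at most 16 * max_m_blocks rows and, when more rows remain, folds several chunks into a single launch via par (bounded by max_par, and only valid without padding). A minimal Python sketch of that schedule; the shape at the end is illustrative:

def marlin_m_schedule(tot_m, max_m_blocks, max_par, pad=0):
    """Yield (start_block, thread_m_blocks, par) per kernel launch."""
    tot_m_blocks = (tot_m + 15) // 16
    launches = []
    i = 0
    while i < tot_m_blocks:
        thread_m_blocks = tot_m_blocks - i
        par = 1
        if thread_m_blocks > max_m_blocks:
            par = min((16 * thread_m_blocks - pad) // (16 * max_m_blocks), max_par)
            i += max_m_blocks * (par - 1)
            thread_m_blocks = max_m_blocks
        launches.append((i, thread_m_blocks, par))
        i += max_m_blocks
    return launches

print(marlin_m_schedule(tot_m=256, max_m_blocks=4, max_par=16))
# [(12, 4, 4)] -> a single launch covering all 16 row-blocks with par = 4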
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, torch::Tensor size_m_tensor,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               int sms, bool is_k_full) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  // Verify num_bits
  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int pack_factor = 32 / num_bits;

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);

  // Verify B
  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
              " is not divisible by tile_size = ", gptq_marlin::tile_size);
  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not divisible by tile_size = ", gptq_marlin::tile_size);
  int actual_size_n =
      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);

  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
  TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");

  // Alloc buffers
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);
  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  // int sms = -1; //zbx

  // Verify g_idx and perm
  TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
                  (g_idx.size(0) == size_k && perm.size(0) == size_k),
              "Unexpected g_idx.size(0) = ", g_idx.size(0),
              " and perm.size(0) = ", perm.size(0),
              ", where size_k = ", size_k);

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(0) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(0);

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }
  } else {
    if (num_groups > 1) {
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by b_scales.size(0) = ",
                  b_scales.size(0));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify workspace size
  TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
              ", is not divisible by min_thread_n = ",
              gptq_marlin::min_thread_n);
  int min_workspace_size =
      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);

  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    gptq_marlin::marlin_mm_f16i4<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
        a_tmp.data_ptr<at::Half>(), size_m_tensor.data_ptr<int>(), size_m,
        size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
        is_k_full, num_groups, group_size, dev,
        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
        gptq_marlin::max_par);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
        size_m_tensor.data_ptr<int>(), size_m, size_n, size_k,
        workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
        thread_n, sms, gptq_marlin::max_par);
  } else {
    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
  }

  return c;
}

#endif
\ No newline at end of file
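Taken together, the shape checks above imply the calling convention sketched below once the extension is built under the vLLMMarlin name from setup.py later in this commit. This is only an illustration of the expected tensor shapes and workspace sizing; the packed weights and scales are random, so the output values are not meaningful.

import torch
import vLLMMarlin  # extension target declared in csrc/custom_marlin/setup.py

m, n, k = 16, 3584, 5120              # illustrative GEMM shape
num_bits, pack_factor = 4, 32 // 4
group_size = 64

a = torch.randn(m, k, dtype=torch.float16, device="cuda")
# b_q_weight is tiled: [k / tile_size, n * tile_size / pack_factor], int32
b_q_weight = torch.randint(int(-1e9), int(1e9), (k // 16, n * 16 // pack_factor),
                           dtype=torch.int32, device="cuda")
b_scales = torch.randn(k // group_size, n, dtype=torch.float16, device="cuda")
g_idx = torch.empty(0, dtype=torch.int32, device="cuda")   # empty -> no act_order
perm = torch.empty(0, dtype=torch.int32, device="cuda")
# min_thread_n = 64, max_par = 16  ->  workspace of (n / 64) * 16 ints
workspace = torch.zeros(n // 64 * 16, dtype=torch.int32, device="cuda")
size_m_tensor = torch.tensor([m], dtype=torch.int32, device="cuda")

c = vLLMMarlin.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
                                workspace, num_bits, size_m_tensor,
                                m, n, k, -1, True)
print(c.shape)  # torch.Size([16, 3584])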
csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh
0 → 100644
// Adapted from
// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin
// Copyright 2024 The vLLM team.
// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace gptq_marlin {

// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small
// tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace gptq_marlin
\ No newline at end of file
csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh
0 → 100644
// Adapted from
// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin
// Copyright 2024 The vLLM team.
// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace gptq_marlin {

template <typename scalar_t>
class ScalarType {};

template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;

  static __device__ float inline num2float(const half x) {
    return __half2float(x);
  }

  static __device__ half2 inline num2num2(const half x) {
    return __half2half2(x);
  }

  static __device__ half2 inline nums2num2(const half x1, const half x2) {
    return __halves2half2(x1, x2);
  }

  static __host__ __device__ half inline float2num(const float x) {
    return __float2half(x);
  }
};

template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};

}  // namespace gptq_marlin

#endif
\ No newline at end of file
csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu
0 → 100644
#include "gptq_marlin.cuh"

namespace gptq_marlin {

static constexpr int repack_stages = 8;

static constexpr int repack_threads = 256;

static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr,
    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {}

}  // namespace gptq_marlin

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight,
                                 torch::Tensor& perm, int64_t size_k,
                                 int64_t size_n, int64_t num_bits) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
constexpr
int
pack_factor
=
32
/
num_bits
;
int
k_tiles
=
size_k
/
tile_k_size
;
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
int
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
if
(
start_k_tile
>=
k_tiles
)
{
return
;
}
int
finish_k_tile
=
min
(
start_k_tile
+
block_k_tiles
,
k_tiles
);
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
repack_stages
-
2
>
();
__syncthreads
();
};
extern
__shared__
int4
sh
[];
constexpr
int
perm_size
=
tile_k_size
/
4
;
int4
*
sh_perm_ptr
=
sh
;
int4
*
sh_pipe_ptr
=
sh_perm_ptr
;
if
constexpr
(
has_perm
)
{
sh_pipe_ptr
+=
perm_size
;
}
constexpr
int
tile_ints
=
tile_k_size
/
pack_factor
;
constexpr
int
stage_n_threads
=
tile_n_size
/
4
;
constexpr
int
stage_k_threads
=
has_perm
?
tile_k_size
:
tile_ints
;
constexpr
int
stage_size
=
stage_k_threads
*
stage_n_threads
;
auto
load_perm_to_shared
=
[
&
](
int
k_tile_id
)
{
int
first_k_int4
=
(
k_tile_id
*
tile_k_size
)
/
4
;
int4
const
*
perm_int4_ptr
=
reinterpret_cast
<
int4
const
*>
(
perm_ptr
);
if
(
threadIdx
.
x
<
perm_size
)
{
sh_perm_ptr
[
threadIdx
.
x
]
=
perm_int4_ptr
[
first_k_int4
+
threadIdx
.
x
];
}
__syncthreads
();
};
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
k_tile_id
,
int
n_tile_id
)
{
if
(
n_tile_id
>=
n_tiles
)
{
cp_async_fence
();
return
;
}
int
first_n
=
n_tile_id
*
tile_n_size
;
int4
*
sh_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
if
constexpr
(
has_perm
)
{
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
uint32_t
const
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
sh_perm_ptr
);
int
src_k
=
sh_perm_int_ptr
[
k_id
];
int
src_k_packed
=
src_k
/
pack_factor
;
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[
src_k_packed
*
size_n
+
first_n
+
(
n_id
*
4
)])));
}
}
else
{
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
int
first_k_packed
=
first_k
/
pack_factor
;
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[(
first_k_packed
+
k_id
)
*
size_n
+
first_n
+
(
n_id
*
4
)])));
}
}
cp_async_fence
();
};
auto
repack_tile
=
[
&
](
int
pipe
,
int
k_tile_id
,
int
n_tile_id
)
{
if
(
n_tile_id
>=
n_tiles
)
{
return
;
}
int
warp_id
=
threadIdx
.
x
/
32
;
int
th_id
=
threadIdx
.
x
%
32
;
if
(
warp_id
>=
4
)
{
return
;
}
int
tc_col
=
th_id
/
4
;
int
tc_row
=
(
th_id
%
4
)
*
2
;
constexpr
int
tc_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
int
cur_n
=
warp_id
*
16
+
tc_col
;
constexpr
int
sh_stride
=
64
;
constexpr
uint32_t
mask
=
(
1
<<
num_bits
)
-
1
;
int4
*
sh_stage_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
uint32_t
*
sh_stage_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_stage_ptr
);
uint32_t
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_perm_ptr
);
uint32_t
vals
[
8
];
if
constexpr
(
has_perm
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
k_idx
=
tc_row
+
tc_offsets
[
i
];
uint32_t
src_k
=
sh_perm_int_ptr
[
k_idx
];
uint32_t
src_k_pos
=
src_k
%
pack_factor
;
uint32_t
b1_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
];
uint32_t
b1_cur_val
=
(
b1_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
uint32_t
b2_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
+
8
];
uint32_t
b2_cur_val
=
(
b2_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
vals
[
i
]
=
b1_cur_val
;
vals
[
4
+
i
]
=
b2_cur_val
;
}
}
else
{
uint32_t
b1_vals
[
tile_ints
];
uint32_t
b2_vals
[
tile_ints
];
#pragma unroll
for
(
int
i
=
0
;
i
<
tile_ints
;
i
++
)
{
b1_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
sh_stride
*
i
];
b2_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
8
+
sh_stride
*
i
];
}
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
int
cur_int
=
cur_elem
/
pack_factor
;
int
cur_pos
=
cur_elem
%
pack_factor
;
vals
[
i
]
=
(
b1_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
vals
[
4
+
i
]
=
(
b2_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
}
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if
constexpr
(
num_bits
==
4
)
{
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
}
else
{
constexpr
int
pack_idx
[
4
]
=
{
0
,
2
,
1
,
3
};
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
}
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
0
]
=
res1
;
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
1
]
=
res2
;
}
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
if
constexpr
(
has_perm
)
{
load_perm_to_shared
(
k_tile_id
);
}
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
repack_tile
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
wait_for_stage
();
}
n_tile_id
+=
repack_stages
;
}
}
}
}
// namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
const
pack_factor
=
32
/
num_bits
;
// Verify B
TORCH_CHECK
((
size_k
/
pack_factor
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", pack_factor = "
,
pack_factor
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
==
size_n
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not size_n = "
,
size_n
);
// Verify device and strides
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
TORCH_CHECK
(
b_q_weight
.
dtype
()
==
at
::
kInt
,
"b_q_weight type is not kInt"
);
TORCH_CHECK
(
perm
.
device
().
is_cuda
(),
"perm is not on GPU"
);
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
TORCH_CHECK
(
perm
.
dtype
()
==
at
::
kInt
,
"perm type is not at::kInt"
);
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
b_q_weight
));
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
empty
({
size_k
/
gptq_marlin
::
tile_size
,
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
// Detect if there is act_order
bool
has_perm
=
perm
.
size
(
0
)
!=
0
;
// Get ptrs
uint32_t
const
*
b_q_weight_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
b_q_weight
.
data_ptr
());
uint32_t
const
*
perm_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
perm
.
data_ptr
());
uint32_t
*
out_ptr
=
reinterpret_cast
<
uint32_t
*>
(
out
.
data_ptr
());
// Get dev info
int
dev
=
b_q_weight
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
if
(
false
)
{
}
CALL_IF
(
4
,
false
)
CALL_IF
(
4
,
true
)
CALL_IF
(
8
,
false
)
CALL_IF
(
8
,
true
)
else
{
TORCH_CHECK
(
false
,
"Unsupported repack config: num_bits = "
,
num_bits
,
", has_perm = "
,
has_perm
);
}
return
out
;
}
#endif
\ No newline at end of file
csrc/custom_marlin/gptq_marlin/ops.h
0 → 100644
/**
* @Description :
* @Author : Azure
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 08:35:00
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/torch.h>
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, torch::Tensor size_m_tensor,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               int sms, bool is_k_full);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight,
                                 torch::Tensor& perm, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
\ No newline at end of file
csrc/custom_marlin/setup.py
0 → 100644
from setuptools import setup, Extension
from torch.utils import cpp_extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='vLLMMarlin',
    ext_modules=[
        CUDAExtension(
            'vLLMMarlin',
            [
                #'custom_gguf/dequant.cu',
                'binding.cpp',
                'gptq_marlin/gptq_marlin.cu',
                'gptq_marlin/gptq_marlin_repack.cu',
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': [
                    '-O3',
                    '--use_fast_math',
                    '-Xcompiler', '-fPIC',
                ]
            },
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)
\ No newline at end of file
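A brief note on how a setup script like this is typically used: built in place through the standard setuptools / torch.utils.cpp_extension workflow and then imported under the vLLMMarlin name declared above. The build command below is an assumption about that usual workflow, not something documented elsewhere in this commit.

# Assumed build step:
#   cd csrc/custom_marlin && python setup.py build_ext --inplace
import vLLMMarlin

# The two ops declared in gptq_marlin/ops.h should then be reachable from Python.
print(hasattr(vLLMMarlin, "gptq_marlin_gemm"),
      hasattr(vLLMMarlin, "gptq_marlin_repack"))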
csrc/custom_marlin/test_cuda_graph.py
0 → 100644
import
csv
import
torch
import
torch.nn
as
nn
import
vLLMMarlin
torch
.
set_grad_enabled
(
False
)
from
utils.marlin_utils
import
(
MarlinWorkspace
,
marlin_quantize
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MIN_THREAD_K
,
GPTQ_MARLIN_MAX_PARALLEL
,
)
def
setup_seed
(
seed
):
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
setup_seed
(
20241223
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
global_dtype
=
torch
.
bfloat16
global_device
=
torch
.
device
(
"cuda"
,
0
)
global_num_cases
:
int
=
int
(
50
)
torch
.
cuda
.
set_device
(
0
)
torch
.
backends
.
cudnn
.
enabled
=
True
torch
.
backends
.
cudnn
.
benchmark
=
True
max_batch_size
=
512
max_tp
=
8
L2_size
=
73728
*
1024
def
get_usable_mem
():
properties
=
torch
.
cuda
.
get_device_properties
(
global_device
)
#print(f"Total memory: {properties.total_memory / (1024 ** 3):.2f} GB")
allocated_memory
=
torch
.
cuda
.
memory_allocated
(
global_device
)
#print(f"Currently allocated memory: {allocated_memory / (1024 ** 2):.2f} MB")
reserved_memory
=
torch
.
cuda
.
memory_reserved
(
global_device
)
#print(f"Currently reserved memory: {reserved_memory / (1024 ** 2):.2f} MB")
return
properties
.
total_memory
-
512
*
1024
**
2
-
allocated_memory
# - reserved_memory
def
exp_range
(
start
,
stop
,
step
=
2
):
now
=
start
while
now
<=
stop
:
yield
now
now
*=
step
def
timing
(
func
,
iters
,
epochs
=
100
):
#warmup
for
idx
in
range
(
iters
):
func
(
idx
)
torch
.
cuda
.
synchronize
()
cuda_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
cuda_graph
):
for
idx
in
range
(
iters
):
func
(
idx
)
for
_
in
range
(
2000
):
cuda_graph
.
replay
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
stream
=
torch
.
cuda
.
Stream
()
torch
.
cuda
.
synchronize
()
#with torch.cuda.stream(stream):
start_event
.
record
()
for
_
in
range
(
10
):
cuda_graph
.
replay
()
end_event
.
record
()
torch
.
cuda
.
synchronize
()
elapsed_time_ms0
=
start_event
.
elapsed_time
(
end_event
)
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
torch
.
cuda
.
synchronize
()
#with torch.cuda.stream(stream):
start_event
.
record
()
for
_
in
range
(
epochs
+
10
):
cuda_graph
.
replay
()
end_event
.
record
()
torch
.
cuda
.
synchronize
()
elapsed_time_ms
=
start_event
.
elapsed_time
(
end_event
)
-
elapsed_time_ms0
#print(elapsed_time_ms0, elapsed_time_ms)
return
elapsed_time_ms
/
iters
/
epochs
class
LinearMarlin
(
nn
.
Linear
):
marlin_q_w
:
torch
.
Tensor
marlin_s
:
torch
.
Tensor
g_idx
:
torch
.
Tensor
sort_indices
:
torch
.
Tensor
has_bias
:
bool
def
__init__
(
self
,
in_features
,
out_features
,
bias
=
False
,
device
:
str
=
"cuda"
,
num_bits
:
int
=
4
,
# 4-bit/8-bit is supported
group_size
:
int
=
64
,
# -1, 32, 64, 128
act_order
:
bool
=
False
,
is_k_full
=
True
,
sms
=
-
1
,
# sms in GPU
**
kwargs
,
):
self
.
padding
=
False
assert
device
.
lower
()
!=
"cpu"
,
"Marlin quantized linear only supports GPU device"
if
in_features
%
GPTQ_MARLIN_MIN_THREAD_K
!=
0
or
out_features
%
GPTQ_MARLIN_MIN_THREAD_K
!=
0
:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self
.
padding
=
True
self
.
orin_in_features
=
in_features
self
.
orin_out_features
=
out_features
in_features
=
(
in_features
+
GPTQ_MARLIN_MIN_THREAD_K
-
1
)
//
GPTQ_MARLIN_MIN_THREAD_K
*
GPTQ_MARLIN_MIN_THREAD_K
out_features
=
(
out_features
+
GPTQ_MARLIN_MIN_THREAD_N
-
1
)
//
GPTQ_MARLIN_MIN_THREAD_N
*
GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
super
().
__init__
(
in_features
,
out_features
,
bias
,
device
)
self
.
has_bias
=
bias
self
.
device
=
device
self
.
num_bits
=
num_bits
self
.
group_size
=
group_size
self
.
act_order
=
act_order
# TODO: optimize every shape GEMM
blocks_k
,
blocks_n
=
in_features
//
128
,
out_features
//
128
self
.
sms
=
sms
self
.
is_k_full
=
is_k_full
self
.
weight
.
requires_grad
=
False
self
.
weight
.
t_
()
# Pack Marlin linear
#w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
# self.weight, self.num_bits, self.group_size, self.act_order
#)
marlin_q_w
=
torch
.
randint
(
int
(
-
1e9
),
int
(
1e9
),
(
in_features
//
16
,
out_features
*
2
),
device
=
device
,
dtype
=
torch
.
int
)
marlin_s
=
torch
.
randn
((
in_features
//
64
,
out_features
),
device
=
device
)
self
.
workspace
=
MarlinWorkspace
(
self
.
out_features
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
,
self
.
device
)
self
.
marlin_q_w
=
marlin_q_w
self
.
marlin_s
=
marlin_s
self
.
g_idx
=
torch
.
empty
((
0
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
sort_indices
=
torch
.
empty
((
0
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
k
=
self
.
weight
.
shape
[
0
]
self
.
n
=
self
.
weight
.
shape
[
1
]
self
.
weight
=
None
"""
print(in_features, out_features)
print(marlin_q_w.shape)
print(marlin_q_w.dtype)
print(marlin_s.shape)
print(marlin_s.dtype)
print(self.workspace.scratch.shape)
print(self.workspace.scratch.dtype)
print(self.g_idx.shape)
print(self.g_idx.dtype)
print(self.sort_indices.shape)
print(self.sort_indices.dtype)
#print(w_ref.shape)
#print(w_ref.dtype)
"""
#w_ref = None
def
forward
(
self
,
x
:
torch
.
Tensor
,
bsz_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Only support input x as BF16 and FP16
x
=
x
.
to
(
self
.
device
)
orig_shape
=
list
(
x
.
shape
)
orig_dtype
=
x
.
dtype
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
self
.
padding
:
padding_input
=
torch
.
empty
(
x
.
shape
[
0
],
self
.
in_features
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
padding_input
[:,:
self
.
orin_in_features
]
=
x
x
=
padding_input
marlin_s
=
self
.
marlin_s
.
to
(
x
.
dtype
)
#print(self.sms * ((orig_shape[0]+63)//64))
sms
=
self
.
sms
x
=
vLLMMarlin
.
gptq_marlin_gemm
(
x
,
self
.
marlin_q_w
,
marlin_s
,
self
.
g_idx
,
self
.
sort_indices
,
self
.
workspace
.
scratch
,
self
.
num_bits
,
bsz_tensor
,
x
.
shape
[
0
],
self
.
n
,
x
.
shape
[
-
1
],
sms
,
self
.
is_k_full
,
)
# TODO: don't padding bias
if
self
.
has_bias
:
x
=
x
+
self
.
bias
if
self
.
padding
:
x
=
x
[:,:
self
.
orin_out_features
]
orig_shape
[
-
1
]
=
self
.
orin_out_features
else
:
orig_shape
[
-
1
]
=
self
.
out_features
return
x
.
reshape
(
orig_shape
).
to
(
orig_dtype
)
def
benchLinearMarlin
(
input_dim
,
output_dim
):
#, out_file
print
(
"benchmarking MLP Marlin"
)
print
(
"-----------------------------------------------------------"
)
headers
=
[
"batch_size"
,
"tp"
,
"used_time"
,
"bandwidth GB/s"
,
"TFLOPS"
,
"cases"
,
"padding"
,
"sms"
]
print
(
" | "
.
join
(
headers
)
+
"
\n
"
)
rows
=
[]
for
batch_size
in
exp_range
(
1
,
64
):
for
tp
in
exp_range
(
1
,
max_tp
):
torch
.
cuda
.
empty_cache
()
if
output_dim
%
tp
!=
0
:
continue
cur_output_dim
=
output_dim
//
tp
modules
=
[]
inputs
=
[]
data_size
=
int
(
0.53125
*
input_dim
*
cur_output_dim
)
input_size
=
int
(
2
*
batch_size
*
input_dim
)
output_size
=
int
(
2
*
batch_size
*
cur_output_dim
)
usable_mem
=
get_usable_mem
()
-
2
*
input_dim
*
cur_output_dim
min_cases
=
max
(
global_num_cases
,
(
2
*
L2_size
)
//
(
data_size
+
input_size
))
cases
=
int
(
min
(
min_cases
,
(
usable_mem
*
0.8
)
//
(
data_size
+
input_size
)))
#print(usable_mem, data_size, input_size, cases)
bsz_tensor
=
torch
.
tensor
([
batch_size
],
device
=
global_device
,
dtype
=
torch
.
int32
)
if
cases
==
0
:
row
=
[
f
"
{
batch_size
}
"
,
"OOM"
,
"OOM"
,
"OOM"
,
"0"
,
"False"
]
rows
.
append
(
row
)
break
for
_
in
range
(
cases
):
modules
.
append
(
LinearMarlin
(
input_dim
,
cur_output_dim
,
sms
=
56
,
non_equal_division
=
False
).
to
(
device
=
global_device
).
eval
())
inputs
.
append
(
torch
.
randn
(
batch_size
,
1
,
input_dim
,
device
=
global_device
))
def
forward
(
case_id
):
modules
[
case_id
](
inputs
[
case_id
],
bsz_tensor
)
used_time
=
timing
(
forward
,
iters
=
cases
)
bandwidth
=
(
data_size
+
input_size
+
output_size
)
/
used_time
/
1e6
flops
=
2
*
batch_size
*
input_dim
*
cur_output_dim
tflops
=
flops
/
used_time
/
1e9
cur_sms
=
modules
[
0
].
sms
row
=
[
f
"
{
batch_size
}
"
,
f
"
{
tp
}
"
,
f
"
{
used_time
}
"
,
f
"
{
bandwidth
}
"
,
f
"
{
tflops
}
"
,
f
"
{
cases
}
"
,
modules
[
0
].
padding
,
cur_sms
]
rows
.
append
(
row
)
print
(
f
"
{
batch_size
}
"
,
f
"
{
tp
}
"
,
f
"
{
used_time
}
"
,
f
"
{
bandwidth
}
"
,
f
"
{
tflops
}
"
,
f
"
{
cases
}
"
,
modules
[
0
].
padding
,
cur_sms
)
"""
with open(out_file, 'w', newline='') as csvfile:
csvwriter = csv.writer(csvfile)
csvwriter.writerow(headers)
for row in rows:
csvwriter.writerow(row)
"""
"""
markdown_table = " | ".join(headers) + "
\n
"
markdown_table += " | ".join(["---"] * len(headers)) + "
\n
"
for row in rows:
markdown_table += " | ".join(row) + "
\n
"
print(markdown_table)
"""
#print("finish write file", out_file)
#print("-------------------------------------------------------------")
if
__name__
==
"__main__"
:
benchLinearMarlin
(
5120
,
3584
)
exit
(
0
)
max_batch
=
1
cur_batch
=
1
marlin_linear
=
LinearMarlin
(
5120
,
3584
)
input_tensor
=
torch
.
randn
(
max_batch
,
1
,
5120
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
bsz_tensor
=
torch
.
tensor
([
max_batch
],
device
=
"cuda"
,
dtype
=
torch
.
int32
)
out_truth
=
marlin_linear
(
input_tensor
,
bsz_tensor
)
print
(
out_truth
)
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
out_buf
=
marlin_linear
(
input_tensor
,
bsz_tensor
)
for
i
in
range
(
10000
):
g
.
replay
()
#torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3)
marlin_linear
=
LinearMarlin
(
5120
,
3584
)
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
out_buf
=
marlin_linear
(
input_tensor
,
bsz_tensor
)
new_input
=
torch
.
randn
(
cur_batch
,
1
,
5120
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
bsz_tensor
.
copy_
(
torch
.
tensor
([
cur_batch
],
device
=
"cuda"
,
dtype
=
torch
.
int32
))
new_out_truth
=
marlin_linear
(
new_input
,
bsz_tensor
)
input_tensor
[:
cur_batch
].
copy_
(
new_input
)
input_tensor
[
cur_batch
:]
=
0
g
.
replay
()
torch
.
cuda
.
synchronize
()
def
printMinMax
(
tensor
):
abs_tensor
=
torch
.
abs
(
tensor
)
min_val
=
torch
.
min
(
abs_tensor
)
max_val
=
torch
.
max
(
abs_tensor
)
min_indices
=
(
abs_tensor
==
min_val
).
nonzero
(
as_tuple
=
True
)
max_indices
=
(
abs_tensor
==
max_val
).
nonzero
(
as_tuple
=
True
)
print
(
f
"min:
{
min_val
.
item
()
}
"
)
print
(
f
"min idx:
{
min_indices
}
"
)
print
(
f
"max:
{
max_val
.
item
()
}
"
)
print
(
f
"max idx:
{
max_indices
}
"
)
print
(
out_buf
[:
cur_batch
].
shape
)
print
(
new_out_truth
.
shape
)
printMinMax
(
out_buf
[:
cur_batch
])
printMinMax
(
new_out_truth
)
#torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3)
csrc/custom_marlin/utils/__init__.py
0 → 100644
csrc/custom_marlin/utils/format24.py
0 → 100644
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import
torch
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
):
dst_rows
=
torch
.
arange
(
0
,
m
,
device
=
device
)[:,
None
].
repeat
(
1
,
meta_ncols
)
dst_cols
=
torch
.
arange
(
0
,
meta_ncols
,
device
=
device
).
repeat
(
m
,
1
)
# Reorder the rows, then swizzle the 2x2 blocks.
group_x
=
64
group_y
=
32
if
meta_dtype
.
itemsize
==
2
else
16
dst_rows
=
(
dst_rows
//
group_x
*
group_x
+
(
dst_rows
%
2
)
*
2
+
(
dst_rows
%
8
)
//
4
+
((
dst_rows
%
group_y
)
%
4
)
//
2
*
32
+
((
dst_rows
%
group_x
)
//
8
)
*
4
)
topright
=
((
dst_rows
%
2
==
0
)
&
(
dst_cols
%
2
==
1
)).
to
(
torch
.
int8
)
bottomleft
=
((
dst_rows
%
2
==
1
)
&
(
dst_cols
%
2
==
0
)).
to
(
torch
.
int8
)
dst_rows
+=
topright
-
bottomleft
dst_cols
-=
topright
-
bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave
=
2
cols_maj
=
dst_cols
//
interleave
cols_min
=
dst_cols
%
interleave
return
(
cols_maj
*
m
*
interleave
+
dst_rows
*
interleave
+
cols_min
).
view
(
-
1
)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def
sparse_semi_structured_from_dense_cutlass
(
dense
):
if
dense
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
m
,
k
=
dense
.
shape
device
=
dense
.
device
meta_dtype
=
torch
.
int8
if
dense
.
dtype
==
torch
.
int8
:
meta_dtype
=
torch
.
int32
elif
dense
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
int32
]:
meta_dtype
=
torch
.
int16
else
:
raise
RuntimeError
(
f
"Invalid datatype
{
dense
.
dtype
}
of dense matrix"
)
quadbits_per_meta_elem
=
meta_dtype
.
itemsize
*
8
//
4
if
quadbits_per_meta_elem
not
in
(
4
,
8
):
raise
RuntimeError
(
"Invalid number of elements per meta element calculated"
)
if
meta_dtype
==
torch
.
int32
:
if
m
%
16
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 16"
)
else
:
if
m
%
32
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 32"
)
if
k
%
(
4
*
quadbits_per_meta_elem
)
!=
0
:
raise
RuntimeError
(
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
# noqa: E501
)
if
dense
.
dtype
!=
torch
.
float
:
ksparse
=
4
dense_4
=
dense
.
view
(
-
1
,
k
//
ksparse
,
ksparse
)
m0
,
m1
,
m2
,
m3
=
(
dense_4
!=
0
).
unbind
(
-
1
)
else
:
ksparse
=
2
dense_2
=
dense
.
view
(
-
1
,
k
//
ksparse
,
ksparse
)
m0
,
m2
=
m1
,
m3
=
(
dense_2
!=
0
).
unbind
(
-
1
)
meta_ncols
=
k
//
(
ksparse
*
quadbits_per_meta_elem
)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are less than two True values, than False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0
=
m0
&
m1
expr1
=
~
m0
&
m1
expr2
=
~
m0
&
~
m1
bit0
=
expr1
bit1
=
expr2
bit2
=
expr0
|
expr2
|
m3
bit3
=
expr1
|
~
m1
idxs0
=
bit0
|
(
bit1
.
to
(
torch
.
int64
)
<<
1
)
idxs1
=
bit2
|
(
bit3
.
to
(
torch
.
int64
)
<<
1
)
if
dense
.
dtype
!=
torch
.
float
:
sparse0
=
dense_4
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
))
# type: ignore[possibly-undefined]
sparse1
=
dense_4
.
gather
(
-
1
,
idxs1
.
unsqueeze
(
-
1
))
sparse
=
torch
.
stack
((
sparse0
,
sparse1
),
dim
=-
1
).
view
(
m
,
k
//
2
)
else
:
sparse
=
dense_2
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
meta_4
=
idxs0
|
(
idxs1
<<
2
)
meta_n
=
meta_4
.
view
(
(
-
1
,
meta_ncols
,
quadbits_per_meta_elem
)).
to
(
meta_dtype
)
if
quadbits_per_meta_elem
==
4
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
))
elif
quadbits_per_meta_elem
==
8
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
)
|
(
meta_n
[:,
:,
4
]
<<
16
)
|
(
meta_n
[:,
:,
5
]
<<
20
)
|
(
meta_n
[:,
:,
6
]
<<
24
)
|
(
meta_n
[:,
:,
7
]
<<
28
))
# Reorder meta tensor elements.
meta_reordered
=
meta
.
new_empty
(
(
m
*
meta_ncols
,
))
# type: ignore[possibly-undefined]
meta_offsets
=
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
)
meta_reordered
.
scatter_
(
0
,
meta_offsets
,
meta
.
view
(
-
1
))
return
(
sparse
,
meta_reordered
.
view
(
m
,
meta_ncols
))

# This function performs the reverse of the function above - it
# reconstructs the dense matrix from a pair of "compressed" matrix, given
# in the layout used by the CUTLASS backend, and the accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    ksparse = 4 if sparse.dtype != torch.float else 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}"  # noqa: E501
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
            "expected according to the number of columns of meta matrix")

    # Undo meta tensor elements reordering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device)
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack sparse tensor back to original dense tensor, using
    # information provided by meta tensor. Note that torch.float
    # datatype is handled pretty much the same as
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
    # values is encoded as if underlying 8 bytes contain four
    # torch.half/torch.bfloat16 values, where either first two or last
    # two are zeros.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        # dense.scatter_(0, dense_offsets, sparse.view(-1))
        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
    else:
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1))

    return dense.view(m, 2 * k)
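
# Illustrative round-trip sketch (assumes an fp16 matrix that already has
# 2:4 structure, e.g. pruned with mask_creator defined below):
#   dense = torch.randn(128, 128, dtype=torch.half)
#   dense = (dense * mask_creator(dense)).half()
#   sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
#   assert torch.equal(
#       sparse_semi_structured_to_dense_cutlass(sparse, meta), dense)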

def mask_creator(tensor):
    """
    Creates an N:M sparsity mask for the given tensor.

    The mask is built using the N:M ratio: for every block of M weights,
    the M - N smallest-magnitude weights are pruned (set to zero in the
    mask) based on ranked weight value.

    :param N: the number of weights in a group to keep
    :param M: the size of a weight group
    """
    N = 2
    M = 4

    mask = None
    # for i, tensor in enumerate(tensors):
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into "
            f"{M} groups")

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]

    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask
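
# Illustrative example: a 1x8 tensor is split into two groups of 4 and the
# two smallest-magnitude entries of each group are zeroed in the mask:
#   t = torch.tensor([[0.1, 0.9, 0.4, 0.2, 0.8, 0.3, 0.05, 0.7]])
#   mask_creator(t)  # -> [[0., 1., 1., 0., 1., 0., 0., 1.]]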
\ No newline at end of file
csrc/custom_marlin/utils/marlin_24_perms.py
0 → 100644
View file @
877aec85
'''
Date: 2024-11-08 02:46:07
LastEditors: djw
LastEditTime: 2024-11-08 02:46:41
'''
"""This file is used for /tests and /benchmarks"""
from typing import Dict, List

import numpy
import torch

# Precompute permutations for Marlin24 weight and scale shuffling  # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible  # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type  # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core  # noqa: E501
# (without the need to use ldmatrix instructions)  # noqa: E501


def get_perms_24(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
                             4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single: List[int] = []
    for i in range(8):
        scale_perm_single.extend(
            [8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])

    return perm, scale_perm, scale_perm_single


marlin_24_perm: Dict[int, torch.Tensor] = {}
marlin_24_scale_perm: Dict[int, List[int]] = {}
marlin_24_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]:
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
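
# For reference: with the loops above, marlin_24_perm[4] and marlin_24_perm[8]
# each hold 32 * 4 * 8 = 1024 indices, while both scale permutation lists
# contain 64 entries.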
\ No newline at end of file
csrc/custom_marlin/utils/marlin_perms.py
0 → 100644
View file @
877aec85
'''
Date: 2024-11-08 02:46:47
LastEditors: djw
LastEditTime: 2024-11-08 02:46:55
'''
"""This file is used for /tests and /benchmarks"""
from typing import Dict, List

import numpy
import torch

# Precompute permutations for Marlin weight and scale shuffling  # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible  # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type  # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core  # noqa: E501
# (without the need to use ldmatrix instructions)  # noqa: E501


def get_perms(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])

    return perm, scale_perm, scale_perm_single


marlin_perm: Dict[int, torch.Tensor] = {}
marlin_scale_perm: Dict[int, List[int]] = {}
marlin_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]:
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
\ No newline at end of file
csrc/custom_marlin/utils/marlin_utils.py
0 → 100644
View file @
877aec85
"""This file is used for /tests and /benchmarks"""
import
random
import
numpy
import
torch
from
.format24
import
(
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
.marlin_24_perms
import
(
marlin_24_perm
,
marlin_24_scale_perm
,
marlin_24_scale_perm_single
)
from
.marlin_perms
import
(
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
from
.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
,
dequantize_weights
)
__cuda_arch
=
torch
.
cuda
.
get_device_capability
()
MARLIN_TILE
=
16
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
def
is_marlin_supported
():
return
__cuda_arch
[
0
]
>=
8
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
# Permute weights to 16x64 marlin tiles
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
tile
,
size_n
//
tile
,
tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
q_w
=
q_w
.
reshape
((
-
1
,
perm
.
numel
()))[:,
perm
].
reshape
(
q_w
.
shape
)
return
q_w
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
):
# Permute
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
# Pack
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_packed
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
scale_perm
,
scale_perm_single
):
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Reformat to marlin
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
])
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_scale_perm
[
num_bits
],
marlin_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
inject_24
(
w
,
size_k
,
size_n
):
assert
w
.
shape
==
(
size_k
,
size_n
)
mask
=
mask_creator
(
w
.
t
()).
t
().
cuda
().
bool
()
return
(
mask
*
w
).
contiguous
(),
mask
.
contiguous
()
def
check_24
(
w
,
num_rows_to_sample
=
50
,
_verbose
=
False
):
BLOCK_SIZE
=
4
MAX_NON_ZEROS
=
2
w
=
w
.
t
().
contiguous
()
print
(
"check_24: w.shape = {}"
.
format
(
w
.
shape
))
num_rows
,
num_cols
=
w
.
shape
sampled_row_idxs
=
random
.
choices
(
range
(
num_rows
),
k
=
num_rows_to_sample
)
if
_verbose
:
print
(
f
"Sampled row idxs =
{
sampled_row_idxs
}
"
)
total_segments
=
0
non_24_segments
=
0
for
i
in
sampled_row_idxs
:
for
j
in
range
(
0
,
num_cols
-
BLOCK_SIZE
,
BLOCK_SIZE
):
total_segments
+=
1
block
=
w
[
i
,
j
:
j
+
BLOCK_SIZE
]
num_nonzero
=
torch
.
count_nonzero
(
block
)
if
num_nonzero
>
MAX_NON_ZEROS
:
print
(
"i = {} j = {} block = {}"
.
format
(
i
,
j
,
block
))
non_24_segments
+=
1
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
q_24_no_zp_comp
,
meta
=
sparse_semi_structured_from_dense_cutlass
(
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
return
q_24_comp
,
meta
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Inject 2:4 sparsity
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
num_bits
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num_bits
,
marlin_24_perm
[
num_bits
])
marlin_24_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_24_scale_perm
[
num_bits
],
marlin_24_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_24_ref
,
marlin_24_q_w_comp
,
meta
,
marlin_24_s
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
class
MarlinWorkspace
:
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
,
device
):
assert
(
out_features
%
min_thread_n
==
0
),
(
"out_features = {} is undivisible by min_thread_n = {}"
.
format
(
out_features
,
min_thread_n
))
max_workspace_size
=
((
out_features
//
min_thread_n
)
*
max_parallel
)
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
)
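
# Illustrative usage sketch (shapes and hyper-parameters are arbitrary;
# assumes a CUDA device with compute capability >= 8.0):
#   w = torch.randn(4096, 4096, dtype=torch.half, device="cuda")
#   w_ref, q_w, s, g_idx, sort_idx, rand_perm = marlin_quantize(
#       w, num_bits=4, group_size=128, act_order=False)
#   workspace = MarlinWorkspace(4096, GPTQ_MARLIN_MIN_THREAD_N,
#                               GPTQ_MARLIN_MAX_PARALLEL, "cuda")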
\ No newline at end of file
csrc/custom_marlin/utils/quant_utils.py
0 → 100644
View file @
877aec85
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
SUPPORTED_NUM_BITS
=
[
4
,
8
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
):
assert
q_w
.
shape
==
w_ref
.
shape
orig_device
=
q_w
.
device
k_size
,
_
=
q_w
.
shape
g_idx
=
torch
.
zeros
((
k_size
,
),
dtype
=
torch
.
int32
)
for
i
in
range
(
k_size
):
g_idx
[
i
]
=
i
//
group_size
# Simulate act_order by doing a random permutation on K
rand_perm
=
torch
.
randperm
(
k_size
)
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
w_ref
=
w_ref
[
rand_perm
,
:].
contiguous
()
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
)
# Function: Dequantize quantized weights
def
dequantize_weights
(
qweight
,
qzeros
,
scales
,
g_idx
,
bits
=
4
,
group_size
=
128
,
device
=
'cuda:0'
):
# Create a tensor for bitwise right shift operation
wf
=
torch
.
tensor
(
list
(
range
(
0
,
32
,
bits
)),
dtype
=
torch
.
int32
,
device
=
device
).
unsqueeze
(
0
)
# Apply bitwise right shift and convert qzeros to the appropriate type
zeros
=
torch
.
bitwise_right_shift
(
torch
.
unsqueeze
(
qzeros
,
2
).
expand
(
-
1
,
-
1
,
32
//
bits
),
wf
.
unsqueeze
(
0
)).
to
(
torch
.
int16
if
bits
==
8
else
torch
.
int8
)
torch
.
bitwise_and
(
zeros
,
(
2
**
bits
)
-
1
,
out
=
zeros
)
# Reshape the zeros tensor
zeros
=
zeros
+
1
zeros
=
zeros
.
reshape
(
-
1
,
1
,
zeros
.
shape
[
1
]
*
zeros
.
shape
[
2
])
# Reshape the scales tensor
scales
=
scales
.
reshape
(
-
1
,
1
,
scales
.
shape
[
-
1
])
# Similar bitwise right shift operation for qweight and reshape
weight
=
torch
.
bitwise_right_shift
(
torch
.
unsqueeze
(
qweight
,
1
).
expand
(
-
1
,
32
//
bits
,
-
1
),
wf
.
unsqueeze
(
-
1
)).
to
(
torch
.
int16
if
bits
==
8
else
torch
.
int8
)
torch
.
bitwise_and
(
weight
,
(
2
**
bits
)
-
1
,
out
=
weight
)
weight
=
weight
.
reshape
(
-
1
,
group_size
,
weight
.
shape
[
2
])
# Apply dequantization formula and reshape the final weight
weight
=
(
scales
*
(
weight
-
zeros
))
weight
=
weight
.
reshape
(
weight
.
shape
[
0
]
*
weight
.
shape
[
1
],
weight
.
shape
[
2
])
# Return the transposed weight
return
weight
.
transpose
(
0
,
1
)
def
quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
):
orig_device
=
w
.
device
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
half_q_val
=
(
max_q_val
+
1
)
//
2
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
w
=
w
.
view
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
s
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s
*=
2
/
max_q_val
# 2 => symmetric
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
q_w
+=
half_q_val
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s
# Restore original shapes
if
group_size
<
size_k
:
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
q_w
=
reshape_w
(
q_w
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
# Apply act_order
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
rand_perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
assert
(
group_size
<
size_k
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
w_ref
,
q_w
,
g_idx
,
rand_perm
=
permute_rows
(
q_w
,
w_ref
,
group_size
)
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
)
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
orig_device
=
q_w
.
device
sort_indices
=
torch
.
argsort
(
g_idx
).
to
(
dtype
=
torch
.
int32
)
# Sort based on g_idx
g_idx
=
g_idx
[
sort_indices
].
contiguous
()
q_w
=
q_w
[
sort_indices
,
:].
contiguous
()
return
(
q_w
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
sort_indices
.
to
(
device
=
orig_device
),
)
def
gptq_pack
(
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
pack_factor
=
get_pack_factor
(
num_bits
)
assert
size_k
%
pack_factor
==
0
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_res
=
numpy
.
zeros
((
size_k
//
pack_factor
,
size_n
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_res
|=
q_w
[
i
::
pack_factor
,
:]
<<
num_bits
*
i
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_res
def
gptq_unpack
(
q_res
:
torch
.
Tensor
,
num_bits
:
int
,
size_k
:
int
,
size_n
:
int
,
):
pack_factor
=
32
//
num_bits
assert
size_k
%
pack_factor
==
0
orig_device
=
q_res
.
device
q_res
=
q_res
.
cpu
().
numpy
()
q_w
=
numpy
.
zeros
((
size_k
,
size_n
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_w
[
i
::
pack_factor
,
:]
=
(
q_res
>>
(
num_bits
*
i
))
&
((
1
<<
num_bits
)
-
1
)
q_w
=
torch
.
from_numpy
(
q_w
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_w
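
# Illustrative round-trip sketch for the pack helpers (values chosen
# arbitrarily):
#   q = torch.randint(0, 16, (256, 512), dtype=torch.int32)   # 4-bit codes
#   packed = gptq_pack(q, num_bits=4, size_k=256, size_n=512)  # (32, 512)
#   assert torch.equal(gptq_unpack(packed, 4, 256, 512), q)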
\ No newline at end of file
ktransformers/ktransformers_ext/CMakeLists.txt → csrc/ktransformers_ext/CMakeLists.txt
View file @ 877aec85
...
@@ -3,8 +3,16 @@ project(cpuinfer_ext VERSION 0.1.0)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})
set(CMAKE_BUILD_TYPE "Release")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp")
# set(CMAKE_BUILD_TYPE "Debug")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
include(CheckCXXCompilerFlag)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
...
@@ -30,7 +38,7 @@ if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" ON)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
option(KTRANSFORMERS_USE_ROCM "ktransformers: use ROCM" OFF)
...
@@ -147,6 +155,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
endif()
else()
    if (LLAMA_NATIVE)
        list(APPEND ARCH_FLAGS -mfma -mavx -mavx2)
        list(APPEND ARCH_FLAGS -march=native)
    endif()
    if (LLAMA_F16C)
...
@@ -172,6 +181,7 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
        list(APPEND ARCH_FLAGS -mavx512vnni)
    endif()
    if (LLAMA_AVX512_FANCY_SIMD)
        message(STATUS "AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled")
        list(APPEND ARCH_FLAGS -mavx512vl)
        list(APPEND ARCH_FLAGS -mavx512bw)
        list(APPEND ARCH_FLAGS -mavx512dq)
...
@@ -238,9 +248,18 @@ if (WIN32)
    include_directories("$ENV{CUDA_PATH}/include")
    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
elseif (UNIX)
    if (KTRANSFORMERS_USE_CUDA)
        find_package(CUDA REQUIRED)
        include_directories("${CUDA_INCLUDE_DIRS}")
        if (NOT KTRANSFORMERS_USE_MUSA)
            # find_package(CUDA REQUIRED)
            # include_directories("${CUDA_INCLUDE_DIRS}")
            include(CheckLanguage)
            check_language(CUDA)
            if (CMAKE_CUDA_COMPILER)
                message(STATUS "CUDA detected")
                find_package(CUDAToolkit REQUIRED)
                include_directories(${CUDAToolkit_INCLUDE_DIRS})
            endif()
            message(STATUS "enabling CUDA")
            enable_language(CUDA)
            add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
        endif()
...
@@ -278,19 +297,35 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)
set(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5})
message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
file(GLOB_RECURSE FMT_SOURCES
    "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp"
    "${CMAKE_CURRENT_SOURCE_DIR}/*.h"
)
add_custom_target(format
    COMMAND clang-format -i -style=file ${FMT_SOURCES}
    COMMENT "Running clang-format on all source files"
)
add_library(llamafile STATIC ${SOURCE_DIR4})
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "ARCH_FLAGS: ${ARCH_FLAGS}")
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if (WIN32)
    target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib") #CUDA::cudart
elseif (UNIX)
    if (KTRANSFORMERS_USE_CUDA)
        if (NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
            set(ENV{CUDA_HOME} "/usr/local/cuda")
        endif()
        target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
        if (NOT KTRANSFORMERS_USE_MUSA)
            target_link_libraries(${PROJECT_NAME} PRIVATE "${CUDAToolkit_LIBRARY_DIR}/libcudart.so")
        endif()
    if (KTRANSFORMERS_USE_ROCM)
        add_compile_definitions(USE_HIP=1)
...
@@ -304,21 +339,28 @@ endif()
# Define the USE_NUMA option
option(USE_NUMA "Disable NUMA support" OFF)
# Check if the USE_NUMA environment variable is set
if (DEFINED ENV{USE_NUMA})
    set(USE_NUMA ON)
endif()
if (USE_NUMA)
if (USE_NUMA)
    message(STATUS "NUMA support is enabled")
else()
    message(STATUS "NUMA support is disabled")
endif()
find_library(NUMA_LIBRARY NAMES numa)
if (NUMA_LIBRARY AND USE_NUMA)
if (NUMA_LIBRARY AND USE_NUMA)
    message(STATUS "NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support")
    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})
    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)
else()
    message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
endif()
    if (USE_NUMA)
        message(FATAL_ERROR "NUMA library not found - maybe sudo apt install libnuma-dev")
    else()
        message(STATUS "NUMA library not found or user not set USE_NUMA - disabling NUMA support")
    endif()
endif()
\ No newline at end of file
ktransformers/ktransformers_ext/bench/bench_attention.py → csrc/ktransformers_ext/bench/bench_attention.py
View file @ 877aec85
File moved
ktransformers/ktransformers_ext/bench/bench_attention_torch.py → csrc/ktransformers_ext/bench/bench_attention_torch.py
View file @ 877aec85
File moved