gaoqiong / composable_kernel_ROCM · Commits

Commit f0bbc5db, authored Feb 13, 2025 by Bartlomiej Kocot

[CK TILE] GEMM with packed i4

Parent: 0e5e29c4

Changes: 28 files in the commit; this page shows 8 changed files with 479 additions and 92 deletions.
Changed files on this page:

- include/ck_tile/host/host_tensor.hpp (+15 −13)
- include/ck_tile/host/reference/reference_gemm.hpp (+60 −8)
- include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp (+151 −5)
- include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp (+174 −21)
- include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp (+13 −8)
- include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp (+12 −6)
- include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp (+29 −16)
- include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp (+25 −15)
include/ck_tile/host/host_tensor.hpp

```diff
@@ -281,18 +281,18 @@ struct HostTensor
     using Data = std::vector<T>;

     template <typename X>
-    HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
+    HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
     {
     }

     template <typename X, typename Y>
     HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
-        : mDesc(lens, strides), mData(mDesc.get_element_space_size())
+        : mDesc(lens, strides), mData(get_element_space_size())
     {
     }

     template <typename Lengths>
-    HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
+    HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
     {
     }
@@ -302,7 +302,7 @@ struct HostTensor
     {
     }

-    HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
+    HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}

     template <typename OutT>
     HostTensor<OutT> CopyAsType() const
@@ -340,7 +340,11 @@ struct HostTensor
     std::size_t get_element_size() const { return mDesc.get_element_size(); }

-    std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
+    std::size_t get_element_space_size() const
+    {
+        constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
+        return mDesc.get_element_space_size() / PackedSize;
+    }

     std::size_t get_element_space_size_in_bytes() const
     {
@@ -463,29 +467,27 @@ struct HostTensor
     template <typename... Is>
     std::size_t GetOffsetFromMultiIndex(Is... is) const
     {
-        return mDesc.GetOffsetFromMultiIndex(is...);
+        constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
+        return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
     }

     template <typename... Is>
     T& operator()(Is... is)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        return mData[GetOffsetFromMultiIndex(is...)];
     }

     template <typename... Is>
     const T& operator()(Is... is) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        return mData[GetOffsetFromMultiIndex(is...)];
     }

     T& operator()(std::vector<std::size_t> idx)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        return mData[GetOffsetFromMultiIndex(idx)];
     }

     const T& operator()(std::vector<std::size_t> idx) const
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+        return mData[GetOffsetFromMultiIndex(idx)];
     }

     HostTensor<T> transpose(std::vector<size_t> axes = {}) const
```
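What the PackedSize division buys: for pk_int4_t the traits report two logical values per stored element, so the backing std::vector allocates half as many slots and every multi-index lands on the shared byte; mDesc keeps counting logical elements, and for ordinary types PackedSize == 1 leaves the old behavior intact. A minimal stand-alone sketch of the same bookkeeping (the packed_traits/PackedHostBuffer names are illustrative, not ck_tile API):

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for ck_tile::numeric_traits<T>::PackedSize:
// how many logical elements share one stored element.
template <typename T>
struct packed_traits { static constexpr std::size_t PackedSize = 1; };

struct pk_int4_t { unsigned char data; }; // two int4 lanes in one byte
template <>
struct packed_traits<pk_int4_t> { static constexpr std::size_t PackedSize = 2; };

// A 2-D row-major buffer that stores PackedSize logical elements per slot,
// mirroring HostTensor's get_element_space_size() / PackedSize.
template <typename T>
struct PackedHostBuffer
{
    std::size_t rows, cols;
    std::vector<T> data;

    PackedHostBuffer(std::size_t r, std::size_t c)
        : rows(r), cols(c), data(r * c / packed_traits<T>::PackedSize) {}

    // Logical (i, j) -> storage offset, as in GetOffsetFromMultiIndex.
    std::size_t offset(std::size_t i, std::size_t j) const
    {
        return (i * cols + j) / packed_traits<T>::PackedSize;
    }
};

int main()
{
    PackedHostBuffer<pk_int4_t> a(4, 8); // 32 logical int4 values, 16 bytes
    // Logical offset 11 shares storage slot 5 with its neighbor.
    return (a.data.size() == 16 && a.offset(1, 3) == 5) ? 0 : 1;
}
```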
include/ck_tile/host/reference/reference_gemm.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
         for(std::size_t k = 0; k < K; ++k)
         {
-            ADataType v_a = a_element_op(a_m_k(m, k));
-            BDataType v_b = b_element_op(b_k_n(k, n));
-            v_acc += ck_tile::type_convert<AccDataType>(v_a) *
-                     ck_tile::type_convert<AccDataType>(v_b);
+            AccDataType v_a;
+            AccDataType v_b;
+            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+            {
+                const pk_int4_t pk_val  = a_element_op(a_m_k(m, k));
+                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
+                if(k % 2 == 1)
+                    v_a = fp32_val.hi;
+                else
+                    v_a = fp32_val.lo;
+            }
+            else
+            {
+                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
+            }
+            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+            {
+                const pk_int4_t pk_val  = b_element_op(b_k_n(k, n));
+                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
+                if(k % 2 == 1)
+                    v_b = fp32_val.hi;
+                else
+                    v_b = fp32_val.lo;
+            }
+            else
+            {
+                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
+            }
+            v_acc += v_a * v_b;
         }

         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
@@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
     AccDataType acc = 0.0;

     for(int k = 0; k < K; ++k)
     {
+        constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
+        constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
         // Adjust indexing based on matrix layout
         int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) ? row * strideA + k
@@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A,
         int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                           ? col * strideB + k
                           : k * strideB + col;
-        acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
-               ck_tile::type_convert<AccDataType>(B[b_index]);
+        AccDataType v_a;
+        AccDataType v_b;
+        if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+        {
+            const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
+            if(k % 2 == 1)
+                v_a = fp32_val.hi;
+            else
+                v_a = fp32_val.lo;
+        }
+        else
+        {
+            v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
+        }
+        if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+        {
+            const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
+            if(k % 2 == 1)
+                v_b = fp32_val.hi;
+            else
+                v_b = fp32_val.lo;
+        }
+        else
+        {
+            v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
+        }
+        acc += v_a * v_b;
     }

     int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
```
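In both reference paths, consecutive k indices share one packed element: the packed value is fetched at index / packed_size, and k % 2 picks the lane (lo for even, hi for odd). A host-side sketch of a pk_int4_t_to_fp32x2_t-style unpack, assuming plain two's-complement nibbles with the low nibble in the lo lane; note the device-side fp16/bf16 tricks elsewhere in this commit imply an offset-by-8 interpretation, so the real helper's encoding may differ:

```cpp
#include <cstdint>
#include <cstdio>

struct fp32x2 { float lo, hi; };

// Assumed decode: low nibble -> lo lane, high nibble -> hi lane,
// both sign-extended as two's-complement int4 in [-8, 7].
fp32x2 pk_int4_to_fp32x2(uint8_t packed)
{
    const int8_t lo = static_cast<int8_t>(packed << 4) >> 4; // sign-extend low nibble
    const int8_t hi = static_cast<int8_t>(packed) >> 4;      // sign-extend high nibble
    return {static_cast<float>(lo), static_cast<float>(hi)};
}

int main()
{
    // k runs over logical elements: element k lives in packed byte k / 2,
    // and k % 2 picks the lane, exactly as in the reference loops above.
    const uint8_t b_col[2] = {0xF2, 0x7B}; // logical values {2, -1, -5, 7}
    for(int k = 0; k < 4; ++k)
    {
        const fp32x2 v = pk_int4_to_fp32x2(b_col[k / 2]);
        std::printf("k=%d -> %g\n", k, (k % 2 == 1) ? v.hi : v.lo);
    }
    return 0;
}
```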
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -9,20 +9,166 @@
 namespace ck_tile {
 namespace element_wise {

+#if 0
+// Fast int4x4 to fp16x8_t data type conversion based on paper
+// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
+// (https://arxiv.org/abs/2211.10017) and implementation:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+
+    int lo;
+    int hi;
+    // Extract the two int4 at low bit and create two fp16 number.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
+    // Extract the two int4 at high bit and create two fp16 number.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
+
+    const int SUB = 0xE408E408; // half2 {-1032, -1032}
+    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
+    const int ADD = 0xd480d480; // half2 {-72, -72}
+
+    fp16x4_t res;
+    // for two fp16 from lowbit, subtract 1032 to get correct fp16 value
+    asm volatile("v_pk_add_f16 %0, %1, %2"
+                 : "=v"(res.lo)
+                 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
+    // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
+    asm volatile("v_pk_fma_f16 %0, %1, %2, %3"
+                 : "=v"(res.hi)
+                 : "v"(bit_cast<fp16x2_t>(hi)),
+                   "v"(bit_cast<fp16x2_t>(MUL)),
+                   "v"(bit_cast<fp16x2_t>(ADD)));
+    return res;
+}
+
+CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+
+    int lo;
+    int hi;
+    // Extract the two int4 at low bit and create two fp16 number.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
+    // Extract the two int4 at high bit and create two fp16 number.
+    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
+
+    const int SUB = 0xE408E408; // half2 {-1032, -1032}
+    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
+    const int ADD = 0xd480d480; // half2 {-72, -72}
+
+    fp16x4_t res;
+    asm volatile("v_pk_add_f16 %0, %1, %2"
+                 : "=v"(res.lo)
+                 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
+    asm volatile("v_pk_fma_f16 %0, %1, %2, %3"
+                 : "=v"(res.hi)
+                 : "v"(bit_cast<fp16x2_t>(hi)),
+                   "v"(bit_cast<fp16x2_t>(MUL)),
+                   "v"(bit_cast<fp16x2_t>(ADD)));
+    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
+    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
+    return res;
+}
+
+CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
+{
+    uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
+
+    static constexpr uint32_t fp32_base = 0x4B000000;
+
+    float fp32_intermediates[4];
+
+    uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+    fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
+    fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
+    fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
+    fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
+
+    fp32_intermediates[0] -= 8388616.f;
+    fp32_intermediates[1] -= 8388616.f;
+    fp32_intermediates[2] -= 8388616.f;
+    fp32_intermediates[3] -= 8388616.f;
+
+    bf16x4_t res;
+    res.lo = bit_cast<bf16x2_t>(
+        __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
+    res.hi = bit_cast<bf16x2_t>(
+        __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
+
+    return res;
+}
+
+struct PassThroughPack8
+{
+    template <typename Y, typename X>
+    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
+
+    CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
+    {
+        y.lo = i4_to_half4(bit_cast<int>(x));
+        y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
+    }
+
+    CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
+    {
+        y.lo = i4_to_bhalf4(bit_cast<int>(x));
+        y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
+    }
+
+    constexpr const static bool is_pack8_invocable = true;
+};
+
+struct DequantPack8
+{
+    template <typename Y, typename X, typename Z>
+    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
+
+    CK_TILE_HOST_DEVICE constexpr void
+    operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
+    {
+        y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
+        y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
+    }
+
+    constexpr const static bool is_pack8_invocable = true;
+};
+
 struct PassThroughPack2
 {
     template <typename Y, typename X>
     CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

-    CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
+#if 0
+    CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
     {
         auto t = type_convert<float2_t>(x);
-        y = type_convert<half2_t>(t);
+        y = type_convert<fp16x2_t>(t);
     }
+#endif
+
+    CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
+    {
+        uint8_t x_u8 = bit_cast<uint8_t>(x);
+        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+        uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+
+        y.lo = type_convert<half_t>(x_l);
+        y.hi = type_convert<half_t>(x_h);
+    }
+
     constexpr const static bool is_pack2_invocable = true;
 };
+#endif

 struct PassThrough
 {
```
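i4_to_bhalf4 relies on the classic magic-number conversion: writing a small integer into the low mantissa bits of 2^23 (fp32_base = 0x4B000000) yields the fp32 value 2^23 + n, so one subtraction recovers n. The subtrahend 8388616.f equals 2^23 + 8, which also removes an offset of 8 from each nibble. A portable host-side check of that arithmetic (the device code does the byte placement with __byte_perm instead of memcpy):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Magic-number int -> float conversion: place a byte-sized integer in the
// low mantissa bits of 2^23, reinterpret as float, subtract the base.
float byte_to_float_via_magic(uint8_t n)
{
    const uint32_t bits = 0x4B000000u | n; // fp32 2^23 with n in the mantissa
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f - 8388616.0f; // 2^23 + 8: also removes the +8 nibble offset
}

int main()
{
    assert(byte_to_float_via_magic(0)  == -8.0f); // nibble 0  -> -8
    assert(byte_to_float_via_magic(8)  ==  0.0f); // nibble 8  ->  0
    assert(byte_to_float_via_magic(15) ==  7.0f); // nibble 15 ->  7
    return 0;
}
```

The fp16 variant plays the same game one level down: OR-ing a nibble into 0x6400 (fp16 1024.0) gives 1024 + n, and the -1032 bias in SUB subtracts both the base and the same offset of 8 in a single v_pk_add_f16.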
include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/ops/elementwise.hpp"

 namespace ck_tile {
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
     template <typename PipelineProblem_, typename GemmPolicy_>
     struct GemmTraits_
     {
-        using Problem        = remove_cvref_t<PipelineProblem_>;
-        using Policy         = remove_cvref_t<GemmPolicy_>;
-        using ADataType      = remove_cvref_t<typename Problem::ADataType>;
-        using BDataType      = remove_cvref_t<typename Problem::BDataType>;
-        using CDataType      = remove_cvref_t<typename Problem::CDataType>;
-        using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
+        using Problem         = remove_cvref_t<PipelineProblem_>;
+        using Policy          = remove_cvref_t<GemmPolicy_>;
+        using ADataType       = remove_cvref_t<typename Problem::ADataType>;
+        using BDataType       = remove_cvref_t<typename Problem::BDataType>;
+        using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
+        using CDataType       = remove_cvref_t<typename Problem::CDataType>;
+        using BlockGemmShape  = remove_cvref_t<typename Problem::BlockGemmShape>;

         static constexpr index_t kBlockSize = Problem::kBlockSize;
         static constexpr auto Scheduler     = Problem::Scheduler;
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
         using BWarpTileDistr = remove_cvref_t<decltype(
            make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}))>;

-        using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
-        using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
+        using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(AWarpTileDistr{}))>;
+        using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(BWarpTileDistr{}))>;

         // TODO: Should we have two policies? Interwave & Intrawave ??
         static constexpr index_t InterWaveSchedulingMacClusters = 1;
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
     public:
     using Traits = GemmTraits_<Problem_, Policy_>;

-    using ADataType = remove_cvref_t<typename Traits::ADataType>;
-    using BDataType = remove_cvref_t<typename Traits::BDataType>;
-    using CDataType = remove_cvref_t<typename Traits::CDataType>;
+    using ADataType       = remove_cvref_t<typename Traits::ADataType>;
+    using BDataType       = remove_cvref_t<typename Traits::BDataType>;
+    using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
+    using CDataType       = remove_cvref_t<typename Traits::CDataType>;

     using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -105,6 +108,11 @@ struct BlockUniversalGemmAsBsCr
     static constexpr auto Scheduler = Traits::Scheduler;

+    static constexpr index_t APackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
+    static constexpr index_t BPackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
+
     using I0 = number<0>;
     using I1 = number<1>;
@@ -208,6 +216,8 @@ struct BlockUniversalGemmAsBsCr
         });

         using CWarpDstr   = typename WarpGemm::CWarpDstr;
+        using AWarpTensor = typename WarpGemm::AWarpTensor;
+        using BWarpTensor = typename WarpGemm::BWarpTensor;
         using CWarpTensor = typename WarpGemm::CWarpTensor;

         constexpr auto c_warp_y_lengths =
@@ -217,10 +227,58 @@ struct BlockUniversalGemmAsBsCr
         // hot loop:
         static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
+                AWarpTensor a_warp_tile;
+                if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        AWarpTensor::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
+                    static_assert(GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(
+                            a_warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
+                            in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
+                }
                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
+                    BWarpTensor b_warp_tile;
+                    if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+                    {
+                        constexpr index_t UnaryOpSize = 8;
+                        const element_wise::PassThroughPack8 elementwise_op{};
+                        const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
+                        constexpr index_t thread_buffer_size =
+                            BWarpTensor::get_thread_buffer_size() / UnaryOpSize;
+                        static_assert(GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+                        using ComputeVectorType =
+                            ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                        static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                            elementwise_op(
+                                b_warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
+                                in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
+                        });
+                    }
+                    else
+                    {
+                        b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
+                    }

                     // read C warp tensor from C block tensor-
                     CWarpTensor c_warp_tensor;
@@ -342,11 +400,59 @@ struct BlockUniversalGemmAsBsCr
         static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                 // read A warp tensor from A block window
-                load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
+                if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
+                    static_assert(GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(a_warp_tiles_(mIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
+                }
             });
             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                 // read B warp tensor from B Block window
-                load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
+                if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::BWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
+                    static_assert(GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(b_warp_tiles_(nIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
+                }
             });
         });
     }
@@ -504,12 +610,59 @@ struct BlockUniversalGemmAsBsCr
         // TODO check if a_warp_tiles has same desc as a_warp_window
         static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
             static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                 // read A warp tensor from A block window
-                load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
+                if constexpr(std::is_same_v<ADataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(a_warp_windows(mIter)(kIter));
+                    static_assert(GemmTraits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(a_warp_tiles_(mIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
+                }
             });
             static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                 // read B warp tensor from B Block window
-                load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
+                if constexpr(std::is_same_v<BDataType, pk_int4_t>)
+                {
+                    constexpr index_t UnaryOpSize = 8;
+                    const element_wise::PassThroughPack8 elementwise_op{};
+                    constexpr index_t thread_buffer_size =
+                        GemmTraits::BWarpTile::get_thread_buffer_size() / UnaryOpSize;
+                    const auto in_dstr_tensors = load_tile(b_warp_windows(nIter)(kIter));
+                    static_assert(GemmTraits::BWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
+                    using ComputeVectorType =
+                        ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
+                    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
+                        elementwise_op(b_warp_tiles_(nIter)(kIter)
+                                           .get_thread_buffer()
+                                           .template get_as<ComputeVectorType>()(i),
+                                       in_dstr_tensors.get_thread_buffer()
+                                           .template get_as<pk_int4x4_t>()[i]);
+                    });
+                }
+                else
+                {
+                    b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
+                }
            });
        });
    }
```
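All six pk_int4 branches above follow one pattern: load the tile in its packed storage type, then widen eight logical values per iteration into the ComputeDataType register tile through element_wise::PassThroughPack8 (UnaryOpSize == 8: one pk_int4x4_t in, one 8-wide compute vector out). A host-side analog of that widening loop, with plain arrays standing in for distributed tensors and an assumed low-nibble-first, two's-complement decode:

```cpp
#include <array>
#include <cstdint>

// Host-side stand-ins: one pk_int4x4 packs 8 int4 lanes into 4 bytes.
struct pk_int4x4 { std::array<uint8_t, 4> bytes; };
using fp32x8 = std::array<float, 8>; // the "ComputeVectorType" analog

struct PassThroughPack8Analog
{
    void operator()(fp32x8& y, const pk_int4x4& x) const
    {
        for(int b = 0; b < 4; ++b)
        {
            // Sign-extend both nibbles of each byte (lane order assumed).
            const int8_t lo = static_cast<int8_t>(x.bytes[b] << 4) >> 4;
            const int8_t hi = static_cast<int8_t>(x.bytes[b]) >> 4;
            y[2 * b]     = static_cast<float>(lo);
            y[2 * b + 1] = static_cast<float>(hi);
        }
    }
};

// Mirrors the static_for<0, thread_buffer_size, 1> loop in the diff:
// each of the N packed chunks widens into the compute-type tile.
template <std::size_t N>
void widen_thread_buffer(std::array<fp32x8, N>& reg_tile,
                         const std::array<pk_int4x4, N>& packed_tile)
{
    const PassThroughPack8Analog elementwise_op{};
    for(std::size_t i = 0; i < N; ++i)
        elementwise_op(reg_tile[i], packed_tile[i]);
}

int main()
{
    std::array<pk_int4x4, 1> in{};
    in[0].bytes = {0x10, 0x32, 0x54, 0x76}; // lanes 0..7 -> 0,1,2,3,4,5,6,7
    std::array<fp32x8, 1> out{};
    widen_thread_buffer(out, in);
    return out[0][7] == 7.0f ? 0 : 1;
}
```

Only the non-packed path keeps the old direct `load_tile` into the warp tile; the packed path pays one extra register-to-register conversion per chunk but leaves LDS traffic in the compact 4-bit form.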
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp

```diff
@@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
     using CDataType      = remove_cvref_t<typename Problem::CDataType>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

+    static constexpr index_t APackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
+    static constexpr index_t BPackedSize =
+        ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
+
     using ALayout = remove_cvref_t<typename Problem::ALayout>;
     using BLayout = remove_cvref_t<typename Problem::BLayout>;
     using CLayout = remove_cvref_t<typename Problem::CLayout>;
@@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
         // A/B split schedule
         // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
-        constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
-                                                ? A_LDS_Read_Inst_Num
-                                                : A_LDS_Read_Inst_Num / 2;
-        constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
-                                                ? B_LDS_Read_Inst_Num
-                                                : B_LDS_Read_Inst_Num / 2;
+        constexpr auto num_ds_read_inst_a =
+            A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
+                                                                     : A_LDS_Read_Inst_Num / 2;
+        constexpr auto num_ds_read_inst_b =
+            B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
+                                                                     : B_LDS_Read_Inst_Num / 2;

         constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
         constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
@@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
         constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
         constexpr auto ds_read_a_issue_cycle =
-            A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
+            A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
         constexpr auto ds_read_b_issue_cycle =
-            B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
+            B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
         constexpr auto ds_read_a_mfma_rate =
             (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
         constexpr auto ds_read_b_mfma_rate =
```
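Both fixes keep the scheduler's byte accounting honest for packed types: an LDS read of pk_int4 data moves sizeof(ADataType) / APackedSize bytes per logical element, so the "is this a 16-byte ds_read?" test must divide by the packed size before comparing. A constexpr restatement of the issue-rate arithmetic with example numbers (the widths and types here are illustrative, not taken from a real tile configuration):

```cpp
#include <cstdio>

// Mirror of GemmPipelineAgBgCrCompV3's bookkeeping, with PackedSize folded
// into the byte width of one LDS read.
constexpr int ds_read_mfma_rate(int lds_read_width, // logical elements per ds_read
                                int elem_bytes,     // sizeof(DataType)
                                int packed_size,    // PackedSize
                                int mfma_cycle)     // 16 or 32
{
    const int read_bytes  = lds_read_width * elem_bytes / packed_size;
    const int issue_cycle = read_bytes == 16 ? 8 : 4; // 16-byte reads issue slower
    return (mfma_cycle - 4 + 2 * issue_cycle - 1) / (2 * issue_cycle);
}

int main()
{
    // fp16 (2 bytes, PackedSize 1), 8-wide read -> 16-byte ds_read.
    std::printf("fp16 x8:    %d\n", ds_read_mfma_rate(8, 2, 1, 32));  // 2
    // pk_int4 (1 byte holds 2 values), 32 logical elements -> still 16 bytes.
    std::printf("pk_int4 x32: %d\n", ds_read_mfma_rate(32, 1, 2, 32)); // 2
    // bf16, 4-wide read -> 8 bytes, so the other branch is taken.
    std::printf("bf16 x4:    %d\n", ds_read_mfma_rate(4, 2, 1, 32));  // 4
    return 0;
}
```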
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp

```diff
@@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
     {
-        constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
-                                        MakeALdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
+        constexpr index_t smem_size_a =
+            sizeof(typename Problem::ADataType) *
+            MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
         return smem_size_a;
     }

     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
     {
-        constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
-                                        MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
+        constexpr index_t smem_size_b =
+            sizeof(typename Problem::BDataType) *
+            MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
         return smem_size_b;
     }
@@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
         using AccDataType = float;
         using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
         using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
-        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
-                                                   typename Problem::BDataType,
+        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
+                                                   typename Problem::ComputeDataType,
                                                    AccDataType,
                                                    WarpTile::at(I0),
                                                    WarpTile::at(I1),
```
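The LDS descriptor counts logical elements, while sizeof(pk_int4_t) already covers two of them, so without the division the pk_int4 footprint would be overstated by 2x. A compile-time illustration with hypothetical tile numbers:

```cpp
// Hypothetical A tile staged in LDS: 128 x 32 logical elements.
constexpr int element_space_size = 128 * 32;

// fp16: sizeof == 2, PackedSize == 1 -> 8192 bytes.
constexpr int smem_fp16 = 2 * element_space_size / 1;

// pk_int4_t: sizeof == 1 byte holding 2 logical values (PackedSize == 2)
// -> 2048 bytes; sizeof * element_space alone would claim 4096.
constexpr int smem_pk_int4 = 1 * element_space_size / 2;

static_assert(smem_fp16 == 8192, "fp16 LDS footprint");
static_assert(smem_pk_int4 == 2048, "packed int4 LDS footprint");
```

The WarpGemm change in the same file is the other half of the story: the MFMA dispatcher now sees the widened ComputeDataType on both operands rather than the packed storage types.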
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp

```diff
@@ -13,14 +13,16 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename Traits_>
+          typename Traits_,
+          typename ComputeDataType_ = ADataType_>
 struct GemmPipelineProblemBase
 {
     using Traits = remove_cvref_t<Traits_>;

-    using ADataType = remove_cvref_t<ADataType_>;
-    using BDataType = remove_cvref_t<BDataType_>;
-    using CDataType = remove_cvref_t<CDataType_>;
+    using ADataType       = remove_cvref_t<ADataType_>;
+    using BDataType       = remove_cvref_t<BDataType_>;
+    using CDataType       = remove_cvref_t<CDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;

     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
@@ -53,13 +55,15 @@ struct GemmPipelineProblemBase
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
     {
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
         if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
         {
             constexpr index_t pixels_per_thread =
                 BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
-            return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
+            return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
                        ? pixels_per_thread
-                       : VectorLoadSize / sizeof(ADataType);
+                       : PackedSize * VectorLoadSize / sizeof(ADataType);
         }
         else
         {
@@ -69,17 +73,19 @@ struct GemmPipelineProblemBase
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
     {
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
         if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
         {
             constexpr index_t pixels_per_thread =
                 BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
-            return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
+            return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
                        ? pixels_per_thread
-                       : VectorLoadSize / sizeof(BDataType);
+                       : PackedSize * VectorLoadSize / sizeof(BDataType);
         }
         else
         {
-            return VectorLoadSize / sizeof(BDataType);
+            return PackedSize * VectorLoadSize / sizeof(BDataType);
         }
     }
@@ -143,9 +149,14 @@ template <typename ADataType_,
           typename BDataType_,
           typename CDataType_,
           typename BlockGemmShape_,
-          typename Traits_>
-using GemmPipelineProblem =
-    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
+          typename Traits_,
+          typename ComputeDataType_ = ADataType_>
+using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
+                                                    BDataType_,
+                                                    CDataType_,
+                                                    BlockGemmShape_,
+                                                    Traits_,
+                                                    ComputeDataType_>;

 template <typename ADataType_,
           typename BDataType_,
@@ -154,14 +165,16 @@ template <typename ADataType_,
           typename Traits_,
           GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
           bool HasHotLoop_                 = true,
-          TailNumber TailNum_              = TailNumber::Full>
+          TailNumber TailNum_              = TailNumber::Full,
+          typename ComputeDataType_        = ADataType_>
 struct UniversalGemmPipelineProblem
 {
     using Traits = remove_cvref_t<Traits_>;

-    using ADataType = remove_cvref_t<ADataType_>;
-    using BDataType = remove_cvref_t<BDataType_>;
-    using CDataType = remove_cvref_t<CDataType_>;
+    using ADataType       = remove_cvref_t<ADataType_>;
+    using BDataType       = remove_cvref_t<BDataType_>;
+    using CDataType       = remove_cvref_t<CDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;

     using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
```
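Because the new ComputeDataType_ parameter is trailing and defaults to ADataType_, every existing instantiation keeps compiling unchanged; a packed-i4 problem simply passes a real compute type explicitly. The alignment change follows the same logic as the LDS sizing: one 16-byte load moves PackedSize times more logical elements. A standalone sketch of the changed branch, with VectorLoadSize = 16 bytes assumed:

```cpp
// Alignment in logical elements for a vector load of the given element type,
// mirroring GetAlignmentB's RowMajor branch above.
constexpr int vector_load_size = 16; // bytes (assumed)

constexpr int alignment(int elem_bytes, int packed_size, int pixels_per_thread)
{
    const int max_vec = packed_size * vector_load_size / elem_bytes;
    return pixels_per_thread < max_vec ? pixels_per_thread : max_vec;
}

// fp16: 16 B / 2 B = 8 elements per load.
static_assert(alignment(2, 1, 64) == 8, "fp16 alignment");
// pk_int4_t: one byte holds two values, so 16 B covers 32 logical elements.
static_assert(alignment(1, 2, 64) == 32, "pk_int4 alignment");
```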
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp

```diff
@@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy
         constexpr index_t BlockSize = Problem::kBlockSize;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
         constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
+        constexpr index_t PackedSize =
+            ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
         // Assume DataType is even!
-        if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
-                     elements_per_thread % (16 / sizeof(DataType)) == 0)
+        if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
+                     elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
+                     PackedSize == 2)
         {
-            return (16 / sizeof(DataType));
+            return (PackedSize * 32 / sizeof(DataType));
         }
-        else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (8 / sizeof(DataType)) == 0)
+        else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
         {
-            return (8 / sizeof(DataType));
+            return (PackedSize * 16 / sizeof(DataType));
         }
-        else if constexpr(sizeof(DataType) >= 4 &&
-                          XPerTile % (4 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (4 / sizeof(DataType)) == 0)
+        else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
         {
-            return (4 / sizeof(DataType));
+            return (PackedSize * 8 / sizeof(DataType));
         }
-        else if constexpr(sizeof(DataType) >= 2 &&
-                          XPerTile % (2 / sizeof(DataType)) == 0 &&
-                          elements_per_thread % (2 / sizeof(DataType)) == 0)
+        else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
+                          XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
         {
-            return (2 / sizeof(DataType));
+            return (PackedSize * 4 / sizeof(DataType));
         }
+        else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
+                          XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
+        {
+            return (PackedSize * 2 / sizeof(DataType));
+        }
         else
         {
-            return 1;
+            return PackedSize;
         }
     }
@@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
     {
         using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
         using WarpTile   = typename Problem::BlockGemmShape::WarpTile;
-        using WarpGemm   = WarpGemmMfmaDispatcher<typename Problem::ADataType,
-                                                  typename Problem::BDataType,
+        using WarpGemm   = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
+                                                  typename Problem::ComputeDataType,
                                                   typename Problem::CDataType,
                                                   WarpTile::at(I0),
                                                   WarpTile::at(I1),
```
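The rewritten vector-size ladder folds PackedSize into every rung, so a pk_int4 tile can reach a 32-byte access (64 logical elements) while the first rung stays gated to packed types via the PackedSize == 2 test. A compact stand-alone restatement of the ladder, with behavior inferred from the diff and hypothetical tile numbers:

```cpp
// Candidate vector lengths, largest first, in logical elements;
// each rung mirrors one if-constexpr branch of GetVectorSize above.
constexpr int vector_size(int elem_bytes, int packed_size, int x_per_tile, int elems_per_thread)
{
    const auto fits = [&](int scaled_bytes) {
        const int v = packed_size * scaled_bytes / elem_bytes;
        return x_per_tile % v == 0 && elems_per_thread % v == 0;
    };
    if(fits(32) && packed_size == 2) return packed_size * 32 / elem_bytes; // 32 B, packed only
    if(fits(16)) return packed_size * 16 / elem_bytes;                     // 16 B access
    if(fits(8))  return packed_size * 8 / elem_bytes;
    if(elem_bytes >= packed_size * 4 && fits(4)) return packed_size * 4 / elem_bytes;
    if(elem_bytes >= packed_size * 2 && fits(2)) return packed_size * 2 / elem_bytes;
    return packed_size; // the old fallback of 1, scaled for packed types
}

// pk_int4_t (1 byte, PackedSize 2): 64-wide tile, 128 elems/thread -> 64 elements (32 B).
static_assert(vector_size(1, 2, 64, 128) == 64, "pk_int4 hits the 32-byte rung");
// fp16 (2 bytes, PackedSize 1): first rung is gated off, so 16 B -> 8 elements.
static_assert(vector_size(2, 1, 64, 128) == 8, "fp16 takes the 16-byte rung");
```

Together with the WarpGemmMfmaDispatcher change at the bottom of the file, this closes the loop begun in gemm_pipeline_problem.hpp: packed storage types drive the memory path, while the widened ComputeDataType drives the MFMA selection.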