gaoqiong / composable_kernel · Commits · ef2d6713

Commit ef2d6713, authored May 15, 2023 by carlushuang
Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl
Parents: 1639689e, a1e344b1
Changes: 76 in total; this page shows 16 changed files, with 869 additions and 66 deletions (+869 −66).
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp                                 +2    −1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp                                 +2    −1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp                                 +2    −1
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp       +0    −0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp             +2    −2
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp           +252  −0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp           +418  −0
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp     +0    −0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp                   +8    −19
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                               +59   −1
include/ck/utility/amd_xdlops.hpp                                                                  +39   −1
include/ck/utility/data_type.hpp                                                                   +59   −31
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp                       +21   −2
profiler/include/profiler/profile_gemm_splitk_impl.hpp                                             +3    −3
script/cmake-ck-dev.sh                                                                             +1    −2
script/cmake-ck-release.sh                                                                         +1    −2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp

@@ -46,7 +46,8 @@ __global__ void
                 const CElementwiseOperation c_element_op,
                 const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp

@@ -49,7 +49,8 @@ __global__ void
                 const CElementwiseOperation c_element_op,
                 const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -53,7 +53,8 @@ __global__ void
                 const CElementwiseOperation c_element_op,
                 const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
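All three hunks make the same one-line change: the preprocessor guard that keeps the kernel body only for supported targets now also accepts gfx940. A minimal sketch of how that guard pattern behaves, with a hypothetical kernel name and LDS size (not from this commit), assuming a 256-thread block:

#include <hip/hip_runtime.h>

// During host compilation __HIP_DEVICE_COMPILE__ is undefined, so the body
// always parses; during device compilation the body is kept only for the
// listed gfx targets, and every other target compiles an empty kernel.
__global__ void my_kernel(float* out) // hypothetical name
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__))
    __shared__ float lds_buf[256]; // LDS + MFMA usage is why other targets are excluded
    lds_buf[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = lds_buf[255 - threadIdx.x];
#else
    (void)out; // no-op on targets without the required support
#endif
}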
include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp
→ include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
File moved without changes.
include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp
→ include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp

@@ -3,8 +3,8 @@
 #pragma once

-#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp"

 namespace ck {

 template <typename GridwiseReduction,
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
New file (mode 100644), +252 −0:
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename XDataType,
          typename ComputeDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarGridDesc_M_KBlock,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize>
struct GridwiseNormalizationSplitK1st
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder,
                                              false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static int
    GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
    {
        bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;

        if(is_rightmost_block)
        {
            int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
            int kPerBlock  = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
            int kPerThread =
                kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
            int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;

            if(kPerBlockTail > 0)
            {
                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                    int thread_max_len =
                        (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                    int delta = thread_max_len - kPerBlockTail;
                    delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                    kPerThread += XSrcVectorSize - delta;
                });
            }

            return kPerThread;
        }
        else
        {
            int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
            return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        }
    }
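    // Worked example (hypothetical numbers, not taken from this commit): with
    // KThreadClusterSize = 8, KThreadSliceSize = 8 (so K_BlockTileSize = 64),
    // XSrcVectorSize = 4 (so K_BlockTileStepSize = 32, ThreadBufferNumber = 2),
    // kRaw = 200 and kGridSize = 2, the rightmost block covers kPerBlock = 100
    // elements: every thread starts from kPerThread = 8 * (100 / 64) = 8, and the
    // tail of 100 - 8 * 8 = 36 elements is handed out vector-by-vector by the
    // static_for above, so thread 0 ends with 16 elements and threads 1..7 with
    // 12 each (16 + 7 * 12 = 100).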
    // Calculate mean and variance by welford along k dimension
    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x_global,
                               MeanVarDataType* const p_mean_global,
                               MeanVarDataType* const p_variance_global,
                               int32_t* const p_welford_count_global)
    {
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_id    = get_block_1d_id();
        const index_t k_grid_size        = mean_var_grid_desc_m_kblock.GetLength(I1);
        const index_t block_m_cluster_id = block_global_id / k_grid_size;
        const index_t block_k_cluster_id = block_global_id % k_grid_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             ComputeDataType,
                                             XGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XSrcVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             true>(
                x_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * reduceSizePerBlock +
                                     thread_k_cluster_id * XSrcVectorSize));

        auto mean_var_count_store_index = make_multi_index(
            block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
            block_k_cluster_id);

        auto threadwise_welford_mean_var_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               MeanVarDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarGridDesc_M_KBlock,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto threadwise_welford = ThreadwiseWelford();
        int kRaw                = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        threadwise_welford.max_count_ =
            GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            var_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
        });

        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
        {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf(i));
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
            });
        }

        int welford_count = 0;
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);

            // The value of count is same for all I
            if constexpr(I == MThreadSliceSize - 1)
                welford_count = count;
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m_kblock,
                                                  mean_global_val_buf);

            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  var_thread_buf,
                                                  mean_var_grid_desc_m_kblock,
                                                  var_global_val_buf);

            if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
                p_welford_count_global[block_k_cluster_id] = welford_count;
        }
    }
};

} // namespace ck
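This kernel emits one partial (mean, variance, count) triple per K-block; the companion splitk_2nd kernel then merges those partials into the final statistics. A minimal host-side sketch of that merge step, using Chan et al.'s pairwise combination formula rather than any CK API (all names here are illustrative):

#include <cstdio>
#include <vector>

// Per-partition Welford state: count, mean, and M2 (sum of squared deviations).
struct WelfordState
{
    int count   = 0;
    double mean = 0.0;
    double m2   = 0.0;

    void Update(double x)
    {
        ++count;
        double delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }
};

// Pairwise combination: the same math a 2nd-stage reduction applies to the
// per-block (mean, var, count) triples written by the 1st-stage kernel.
WelfordState Combine(const WelfordState& a, const WelfordState& b)
{
    WelfordState out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;
    double delta = b.mean - a.mean;
    out.mean     = a.mean + delta * b.count / out.count;
    out.m2       = a.m2 + b.m2 + delta * delta * a.count * b.count / out.count;
    return out;
}

int main()
{
    std::vector<double> x = {1, 2, 3, 4, 5, 6, 7, 8};

    // Split K into two partitions, as the split-K grid does across blocks.
    WelfordState lo, hi;
    for(int i = 0; i < 4; ++i)
        lo.Update(x[i]);
    for(int i = 4; i < 8; ++i)
        hi.Update(x[i]);

    WelfordState all = Combine(lo, hi);
    std::printf("mean=%f var=%f\n", all.mean, all.m2 / all.count); // mean=4.5 var=5.25
}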
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
New file (mode 100644), +418 −0. (Diff collapsed in the original view; contents not shown.)
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
→ include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
File moved without changes.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -6,6 +6,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor/static_tensor.hpp"

@@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             auto src_vector_container = src_vector_type{
                 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

-            // apply SrcElementwiseOperation on src_vector_container
-            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                SrcData src_v;
-
-                src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);
-
-                src_vector_container.template AsType<SrcData>()(i) = src_v;
-            });
-
             // copy data from src_vector_container into src_thread_scratch_
             src_thread_scratch_tuple_(thread_scratch_id)
                 .template SetAsType<src_vector_t>(

@@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto data_idx_seq = generate_sequence_v2(
             [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});

-        // TODO type_convert is not used yet!!!!!
         using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
         using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;

@@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 Number<num_dst_vector>{});

             // do data transpose
-            // TODO type_convert is not used yet!!!!!
             transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                 src_vector_refs, dst_vector_refs);
         });
     }
     else
     {
         static_ford<SliceLengths>{}([&](auto idx) {
-            // convert from SrcData to DstData here
-            dst_thread_scratch_(idx) =
-                type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
+            // apply the src elementwise op and convert to DstData under the hood if needed
+            DstData dst_v;
+            src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
+            dst_thread_scratch_(idx) = dst_v;
         });
     }
 #endif
 }
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -27,6 +27,8 @@ enum struct MfmaInstr
     mfma_f32_16x16x8bf16,
     mfma_i32_32x32x8i8,
     mfma_i32_16x16x16i8,
+    mfma_i32_32x32x16i8,
+    mfma_i32_16x16x32i8,
     mfma_f64_16x16x4f64
 };

@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {

@@ -524,17 +570,29 @@ struct MfmaSelector
 #endif
     }

+#if defined(CK_USE_AMD_MFMA_GFX940)
+    template <>
+    static constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x16i8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x32i8;
+    }
+#else
     template <>
     static constexpr auto GetMfma<int8_t, 32, 32>()
     {
         return MfmaInstr::mfma_i32_32x32x8i8;
     }

     template <>
     static constexpr auto GetMfma<int8_t, 16, 16>()
     {
         return MfmaInstr::mfma_i32_16x16x16i8;
     }
+#endif

     static constexpr auto selected_mfma =
         mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
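For the new entries, the table parameters imply how much of the K dimension one instruction consumes per wave: with is_k_reduction == true, that depth is k_per_blk * num_input_blks. A compile-time sanity check of the values above, in plain C++ with no CK dependencies:

// K depth per instruction implied by the mfma_type tables above:
// k_per_blk * num_input_blks, valid when is_k_reduction == true.
constexpr int k_per_instr(int k_per_blk, int num_input_blks)
{
    return k_per_blk * num_input_blks;
}

static_assert(k_per_instr(8, 2) == 16, "mfma_i32_32x32x16i8: 32x32 tile, K = 16");
static_assert(k_per_instr(8, 4) == 32, "mfma_i32_16x16x32i8: 16x16 tile, K = 32");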
include/ck/utility/amd_xdlops.hpp

@@ -297,6 +297,44 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x16i8;
+
+template <>
+struct intrin_mfma_i32_32x32x16i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x16_i8(
+            bit_cast<int64_t>(reg_a),
+            bit_cast<int64_t>(reg_b),
+            reg_c.template AsType<int32x16_t>()[Number<0>{}],
+            0,
+            0,
+            0);
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x32i8;
+
+template <>
+struct intrin_mfma_i32_16x16x32i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x32_i8(
+            bit_cast<int64_t>(reg_a),
+            bit_cast<int64_t>(reg_b),
+            reg_c.template AsType<int32x4_t>()[Number<0>{}],
+            0,
+            0,
+            0);
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f64_16x16x4f64;

@@ -306,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     template <class FloatC>
     __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
     {
-#ifdef __gfx90a__
+#if defined(__gfx90a__) || defined(__gfx940__)
         reg_c.template AsType<double4_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_f64_16x16x4f64(reg_a,
                                                  reg_b,
                                                  reg_c.template AsType<double4_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
 #else
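Both new wrappers reinterpret their int8x8_t fragments as 64-bit scalars before calling the builtin; the bit_cast changes only the type, never the bytes. A host-side equivalent of that packing, written against a plain array rather than CK's vector types (names here are illustrative):

#include <cstdint>
#include <cstring>

// The gfx940 int8 MFMA builtins take their A/B fragments as 64-bit scalars:
// eight packed int8 lanes are reinterpreted, not value-converted.
int64_t pack_int8x8(const int8_t (&v)[8])
{
    int64_t out;
    static_assert(sizeof(out) == sizeof(v), "eight int8 lanes must be exactly 64 bits");
    std::memcpy(&out, v, sizeof(out)); // same bits, new type - no conversion
    return out;
}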
include/ck/utility/data_type.hpp

@@ -898,6 +898,8 @@ struct vector_type<T, 256>
     }
 };

+using int64_t = long;
+
 // fp64
 using double2_t = typename vector_type<double, 2>::type;
 using double4_t = typename vector_type<double, 4>::type;

@@ -974,37 +976,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
         uint32_t int32;
     } u = {x};

-    // When the exponent bits are not all 1s, then the value is zero, normal,
-    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
-    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
-    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
-    // least significant bits of the float mantissa are greater than 0x8000,
-    // or if they are equal to 0x8000 and the least significant bit of the
-    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
-    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
-    // has the value 0x7f, then incrementing it causes it to become 0x00 and
-    // the exponent is incremented by one, which is the next higher FP value
-    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
-    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
-    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
-    // incrementing it causes it to become an exponent of 0xFF and a mantissa
-    // of 0x00, which is Inf, the next higher value to the unrounded value.
-    bool flag0 = ~u.int32 & 0x7f800000;
-
-    // When all of the exponent bits are 1, the value is Inf or NaN.
-    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
-    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
-    // bit being 1. Signaling NaN is indicated by the most significant
-    // mantissa bit being 0 but some other bit(s) being 1. If any of the
-    // lower 16 bits of the mantissa are 1, we set the least significant bit
-    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
-    // the bfloat16's mantissa bits are all 0.
-    bool flag1 = !flag0 && (u.int32 & 0xffff);
-
-    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
-    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
-
     return uint16_t(u.int32 >> 16);
 }

@@ -1062,6 +1033,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
     return type_convert<bhalf_t>(x_fp32);
 }

+// Declare a template function for bf16 conversion using RTN
+template <typename Y, typename X>
+__host__ __device__ constexpr Y bf16_convert_rtn(X x);
+
+// Convert fp32 to bf16 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {x};
+
+    // When the exponent bits are not all 1s, then the value is zero, normal,
+    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
+    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
+    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
+    // least significant bits of the float mantissa are greater than 0x8000,
+    // or if they are equal to 0x8000 and the least significant bit of the
+    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
+    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
+    // has the value 0x7f, then incrementing it causes it to become 0x00 and
+    // the exponent is incremented by one, which is the next higher FP value
+    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
+    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
+    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
+    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
+    // incrementing it causes it to become an exponent of 0xFF and a mantissa
+    // of 0x00, which is Inf, the next higher value to the unrounded value.
+    bool flag0 = ~u.int32 & 0x7f800000;
+
+    // When all of the exponent bits are 1, the value is Inf or NaN.
+    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
+    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
+    // bit being 1. Signaling NaN is indicated by the most significant
+    // mantissa bit being 0 but some other bit(s) being 1. If any of the
+    // lower 16 bits of the mantissa are 1, we set the least significant bit
+    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
+    // the bfloat16's mantissa bits are all 0.
+    bool flag1 = !flag0 && (u.int32 & 0xffff);
+
+    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
+    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
+
+    return uint16_t(u.int32 >> 16);
+}
+
+// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
+template <>
+inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
+{
+    float x_fp32 = static_cast<float>(x);
+    return bf16_convert_rtn<bhalf_t>(x_fp32);
+}
+
 template <typename T>
 struct NumericLimits
 {
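Net effect: plain type_convert<bhalf_t, float> now truncates, and the round-to-nearest-even path lives in the new bf16_convert_rtn. To see the tie-to-even behavior concretely, here is a host-only replica of the same bit manipulation, with uint16_t standing in for bhalf_t:

#include <cstdint>
#include <cstdio>

// Host-side replica of bf16_convert_rtn<bhalf_t, float> above.
uint16_t bf16_rtn(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    bool flag0 = ~u.int32 & 0x7f800000;        // exponent bits not all 1s (not Inf/NaN)
    bool flag1 = !flag0 && (u.int32 & 0xffff); // NaN with nonzero low mantissa bits

    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // round to nearest, ties to even
    u.int32 |= flag1 ? 0x10000 : 0x0;                      // preserve signaling NaN

    return uint16_t(u.int32 >> 16);
}

int main()
{
    // 0x3F808000 is 1.0f plus exactly half a bf16 ulp: the tie rounds to the
    // even mantissa, i.e. down to 1.0 (bf16 bit pattern 0x3f80), whereas plain
    // truncation would give the same value here but rounds all non-ties worse.
    union { uint32_t i; float f; } tie = {0x3F808000u};
    std::printf("0x%04x\n", bf16_rtn(tie.f)); // prints 0x3f80
}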
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp

@@ -6,6 +6,7 @@
 #include <iostream>
 #include <sstream>

+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
 #include "ck/library/utility/host_tensor.hpp"

@@ -66,8 +67,26 @@ struct ReferenceGemm : public device::BaseOperator
                 ADataType v_a;
                 BDataType v_b;

-                arg.a_element_op_(v_a, arg.a_m_k_(m, k));
-                arg.b_element_op_(v_b, arg.b_k_n_(k, n));
+                // use PassThrough instead of ConvertBF16RTN for reference calculation
+                if constexpr(is_same_v<AElementwiseOperation,
+                                       ck::tensor_operation::element_wise::ConvertBF16RTN>)
+                {
+                    ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
+                }
+                else
+                {
+                    arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+                }
+
+                // same for B matrix
+                if constexpr(is_same_v<BElementwiseOperation,
+                                       ck::tensor_operation::element_wise::ConvertBF16RTN>)
+                {
+                    ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
+                }
+                else
+                {
+                    arg.b_element_op_(v_b, arg.b_k_n_(k, n));
+                }

                 v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
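The `if constexpr` dispatch keeps the CPU reference in full precision: when the device pipeline rounds its inputs to bf16 on load, the reference deliberately skips that rounding so verification compares against unrounded ground truth. A stripped-down sketch of the same pattern, with stand-in functor types rather than the CK ones:

#include <type_traits>

// Stand-ins for the CK elementwise ops (illustrative, not the real types).
struct PassThroughLike
{
    void operator()(float& y, float x) const { y = x; }
};
struct ConvertBF16RTNLike
{
    void operator()(float& y, float x) const { y = x; /* device path would round to bf16 */ }
};

// If the configured op is the rounding conversion, substitute identity so the
// reference result stays full precision; otherwise apply the op as configured.
template <typename ElementOp>
float load_for_reference(float x, ElementOp op)
{
    float y;
    if constexpr(std::is_same_v<ElementOp, ConvertBF16RTNLike>)
        PassThroughLike{}(y, x);
    else
        op(y, x);
    return y;
}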
profiler/include/profiler/profile_gemm_splitk_impl.hpp

@@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification,
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{0, 1});
-        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 1});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

@@ -94,7 +94,7 @@ bool profile_gemm_splitk_impl(int do_verification,
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
-    c_device_buf.ToDevice(c_m_n_device_result.mData.data());
+    c_device_buf.SetZero();

     using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
                                                                     BLayout,
script/cmake-ck-dev.sh

@@ -12,9 +12,8 @@ cmake
 -save-temps=$PWD" \
 -D CMAKE_BUILD_TYPE=Release \
 -D BUILD_DEV=ON \
--D GPU_TARGETS="gfx908;gfx90a" \
+-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
 ${MY_PROJECT_SOURCE}
 #-D AMDGPU_TARGETS=gfx90a;gfx908
script/cmake-ck-release.sh

@@ -11,9 +11,8 @@ cmake
 -D CMAKE_CXX_FLAGS="-O3" \
 -D CMAKE_BUILD_TYPE=Release \
 -D BUILD_DEV=OFF \
--D GPU_TARGETS="gfx908;gfx90a" \
+-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
 ${MY_PROJECT_SOURCE}
 #-D AMDGPU_TARGETS=gfx90a;gfx908
(The changed-file listing is paginated; this is page 1 of 4.)