gaoqiong / composable_kernel / Commits

Commit c54b7bc9
authored Sep 19, 2022 by Chao Liu
Merge remote-tracking branch 'origin/develop' into group_norm
parents 9a8967a4 f584ab0c
Changes: 32 · Showing 12 changed files with 558 additions and 57 deletions (+558 -57)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  +3 -2
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp  +4 -0
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp  +339 -0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  +1 -0
include/ck/utility/span.hpp  +67 -0
include/ck/utility/transpose_vectors.hpp  +17 -21
library/include/ck/library/utility/check_err.hpp  +25 -13
library/include/ck/library/utility/fill.hpp  +12 -0
library/include/ck/library/utility/host_tensor.hpp  +45 -13
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  +19 -0
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp  +13 -0
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp  +13 -8
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+3 -2)

@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                     FloatGemmAcc c_new =
                         (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                          math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM]; // O_new
+                        running_sum_new[iM]; // Formula by Dao et al.,
+                    // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1

-                    c_thread_buf(I) = c_new;
+                    c_thread_buf(I) = c_new; // O_new
                 });
             });
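The updated expression is the streaming-softmax rescaling from Dao et al., section 3.1: when a new tile of attention logits arrives, the old partial output is rescaled by its stale max and sum before the new tile's contribution is folded in and renormalized. A minimal CPU sketch of the same per-row update (RunningRow and update are illustrative names, not CK APIs):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // One streaming-softmax step for a single row: prev holds the statistics of
    // all tiles seen so far; (m_t, s_t, acc1) describe the newly computed tile.
    struct RunningRow
    {
        float m; // running max of logits seen so far
        float s; // running sum of exp(logit - m)
        float o; // running output, already divided by s
    };

    RunningRow update(RunningRow prev, float m_t, float s_t, float acc1)
    {
        const float m_new = std::max(prev.m, m_t);
        const float s_new = prev.s * std::exp(prev.m - m_new) + s_t * std::exp(m_t - m_new);
        // Same shape as the kernel's c_new expression: rescale the old output by
        // its stale denominator and max, add the new tile, renormalize.
        const float o_new =
            (prev.s * std::exp(prev.m - m_new) * prev.o + std::exp(m_t - m_new) * acc1) / s_new;
        return {m_new, s_new, o_new};
    }

    int main()
    {
        RunningRow row{-1.0f, 2.0f, 0.5f};
        row = update(row, 0.25f, 1.5f, 0.75f);
        std::printf("m=%f s=%f o=%f\n", row.m, row.s, row.o);
    }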
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp  (+4 -0)

@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
         auto in_global_buf_tuple = generate_tuple(
             [&](auto I) {
+                static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
             },
@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
         auto out_global_buf_tuple = generate_tuple(
             [&](auto I) {
+                static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
             },
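The new static_asserts reject any descriptor in the tuple that is not rank-1 before a buffer is made for it. A standalone standard-C++ analogue of that compile-time guard, using stand-in Desc types rather than CK descriptors:

    #include <cstddef>
    #include <tuple>

    // Desc is a stand-in for a CK tensor descriptor type; only its rank matters here.
    template <std::size_t Rank>
    struct Desc
    {
        static constexpr std::size_t GetNumOfDimension() { return Rank; }
    };

    // Reject non-1D descriptors at compile time, as the added static_asserts do.
    template <typename... Descs>
    constexpr void make_buffers(const std::tuple<Descs...>&)
    {
        static_assert(((Descs::GetNumOfDimension() == 1) && ...),
                      "every descriptor in the tuple must be 1D");
    }

    int main()
    {
        make_buffers(std::tuple<Desc<1>, Desc<1>>{}); // fine
        // make_buffers(std::tuple<Desc<1>, Desc<2>>{}); // would fail to compile
    }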
include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp  (new file, mode 0 → 100644, +339)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <functional>
#include <numeric>
#include <iterator>

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwisePermute,
          typename InGridDesc,
          typename OutGridDesc,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
                                  const OutGridDesc out_grid_desc,
                                  const InDataType* p_in_global,
                                  OutDataType* p_out_global,
                                  const ElementwiseOperation elementwise_op,
                                  const Block2TileMap block_2_tile_map)
{
    __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];

    GridwisePermute::Run(in_grid_desc,
                         out_grid_desc,
                         p_in_global,
                         p_out_global,
                         p_shared,
                         elementwise_op,
                         block_2_tile_map);
}

template <typename InGridDesc,
          typename OutGridDesc,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          index_t BlockSize,
          index_t NPerBlock,
          index_t HPerBlock,
          index_t WPerBlock,
          index_t InBlockLdsExtraW,
          typename InBlockTransferThreadClusterLengths,
          typename InBlockTransferThreadClusterArrangeOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector>
struct GridwisePermute
{
    static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
    static_assert(3 <= InGridDesc::GetNumOfDimension());
    static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
                  SrcVectorDim < InGridDesc::GetNumOfDimension());
    static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
                  DstVectorDim < OutGridDesc::GetNumOfDimension());
    static_assert(SrcVectorDim != DstVectorDim);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    struct Block2TileMap
    {
        static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        static constexpr auto I0 = Number<0>{};

        Block2TileMap()                     = delete;
        Block2TileMap(const Block2TileMap&) = default;
        Block2TileMap(Block2TileMap&&)      = delete;

        ~Block2TileMap() = default;

        Block2TileMap& operator=(const Block2TileMap&) = delete;
        Block2TileMap& operator=(Block2TileMap&&) = delete;

        explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}

        __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
        {
            const auto N0 =
                math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
            const auto H0 =
                math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
            const auto W0 =
                math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);

            const index_t grid_size = N0 * H0 * W0;

            return grid_size;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            static_assert(TopIdx::Size() == 1);

            auto block_1d_id = idx_top[I0];

            const auto N0 =
                math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
            const auto H0 =
                math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
            const auto W0 =
                math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);

            block_1d_id = block_1d_id % (N0 * H0 * W0);

            index_t idx_N0 = block_1d_id / (H0 * W0);
            index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
            index_t idx_W0 = block_1d_id % W0;

            return make_tuple(idx_N0, idx_H0, idx_W0);
        }

        private:
        const InGridDesc desc_;
    };

    using DefaultBlock2TileMap = Block2TileMap;

    // use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
    __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
    {
        return make_naive_tensor_descriptor(
            make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
            make_tuple(Number<HPerBlock * (WPerBlock + InBlockLdsExtraW)>{},
                       Number<WPerBlock + InBlockLdsExtraW>{},
                       I1));
    }

    // for an N-dimensional descriptor, preserve its last 2 dimensions, then merge its leading
    // dimensions into a single one, forming a 3D descriptor:
    // [d(0), d(1), ..., d(N - 2), d(N - 1)] -> [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
    template <typename GridDesc>
    __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
    {
        constexpr index_t NumDim = GridDesc::GetNumOfDimension();
        static_assert(3 <= NumDim);

        const auto merged_desc = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(generate_tuple(
                           [&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
                       make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
                       make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
                       Sequence<NumDim - 2>{},
                       Sequence<NumDim - 1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        return merged_desc;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
               sizeof(InDataType);
    }

    __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
    {
        return DefaultBlock2TileMap{desc};
    }

    __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
                                                            const OutGridDesc& out_grid_desc)
    {
        constexpr index_t NumDim = InGridDesc::GetNumOfDimension();

        // check if we only swap the last 2 dimensions
        bool valid = true;
        static_for<0, NumDim - 2, 1>{}([&](auto I) {
            if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
            {
                valid = false;
            }
        });

        return valid &&
               (in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
                out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
               (in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
                out_grid_desc.GetLength(Number<NumDim - 1>{}));
    }

    template <typename Block2TileMap>
    __device__ static void Run(const InGridDesc in_grid_desc,
                               const OutGridDesc out_grid_desc,
                               const InDataType* p_in_global,
                               OutDataType* p_out_global,
                               void* __restrict__ p_shared,
                               const ElementwiseOperation elementwise_op,
                               const Block2TileMap& block_2_tile_map)
    {
        auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc.GetElementSpaceSize());

        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc.GetElementSpaceSize());

        // each workgroup handles an [NPerBlock, HPerBlock, WPerBlock] slice-transpose problem
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);

        const index_t h_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);

        const index_t w_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);

        // create [NPerBlock, HPerBlock, WPerBlock] shaped LDS buffer
        constexpr auto in_block_desc_nperblock_hperblock_wperblock =
            GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();

        auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<InDataType*>(p_shared),
            in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());

        using BlockSliceLengths          = Sequence<NPerBlock, HPerBlock, WPerBlock>;
        using InBlockTransferAccessOrder = Sequence<0, 1, 2>;

        constexpr index_t SrcVectorDimAfterMerge =
            SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
        constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;

        using ck::tensor_operation::element_wise::PassThrough;

        // merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
        // ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
        const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);

        // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
        auto in_global_load =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                ElementwiseOperation,
                                                PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                BlockSliceLengths,
                                                InBlockTransferThreadClusterLengths,
                                                InBlockTransferThreadClusterArrangeOrder,
                                                InDataType,
                                                InDataType,
                                                decltype(in_grid_desc_n_h_w),
                                                decltype(in_block_desc_nperblock_hperblock_wperblock),
                                                InBlockTransferAccessOrder,
                                                InBlockTransferAccessOrder,
                                                SrcVectorDimAfterMerge,
                                                2,
                                                SrcScalarPerVector,
                                                1,
                                                1,
                                                1,
                                                true,
                                                true>(
                in_grid_desc_n_h_w,
                make_multi_index(
                    n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
                PassThrough{},
                in_block_desc_nperblock_hperblock_wperblock,
                make_multi_index(0, 0, 0),
                PassThrough{});

        // merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
        // ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
        const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);

        // create transposed view of output tensor
        const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
            out_grid_desc_n_w_h,
            make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
                       make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
                       make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));

        // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
        auto out_global_store =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                ElementwiseOperation,
                                                PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                BlockSliceLengths,
                                                InBlockTransferThreadClusterLengths,
                                                InBlockTransferThreadClusterArrangeOrder,
                                                InDataType,
                                                OutDataType,
                                                decltype(in_block_desc_nperblock_hperblock_wperblock),
                                                decltype(out_grid_desc_n_h_w),
                                                InBlockTransferAccessOrder,
                                                InBlockTransferAccessOrder,
                                                2,
                                                DstVectorDimAfterMerge,
                                                1,
                                                DstScalarPerVector,
                                                1,
                                                1,
                                                true,
                                                true>(
                in_block_desc_nperblock_hperblock_wperblock,
                make_multi_index(0, 0, 0),
                PassThrough{},
                out_grid_desc_n_h_w,
                make_multi_index(
                    n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
                elementwise_op);

        in_global_load.Run(in_grid_desc_n_h_w,
                           in_global_buf,
                           in_block_desc_nperblock_hperblock_wperblock,
                           in_block_buf,
                           I0);

        out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
                             in_block_buf,
                             out_grid_desc_n_h_w,
                             out_global_buf,
                             I0);
    }
};

} // namespace ck
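Block2TileMap above turns a linear workgroup id into an (idx_N0, idx_H0, idx_W0) tile coordinate via two divisions and a modulo, where the tile counts are ceil-divisions of the merged tensor lengths by the per-block sizes. A host-side sketch of the same decomposition, with made-up tensor and block sizes:

    #include <cstdio>

    // Ceil-division, as math::integer_divide_ceil does in the kernel.
    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        const int N = 10, H = 70, W = 130; // merged-tensor lengths (illustrative)
        const int NPerBlock = 4, HPerBlock = 32, WPerBlock = 64;

        const int N0 = ceil_div(N, NPerBlock); // tiles along N
        const int H0 = ceil_div(H, HPerBlock); // tiles along H
        const int W0 = ceil_div(W, WPerBlock); // tiles along W

        for(int block_1d_id = 0; block_1d_id < N0 * H0 * W0; ++block_1d_id)
        {
            // same arithmetic as CalculateBottomIndex
            const int id     = block_1d_id % (N0 * H0 * W0);
            const int idx_N0 = id / (H0 * W0);
            const int idx_H0 = (id % (H0 * W0)) / W0;
            const int idx_W0 = id % W0;
            std::printf("block %2d -> tile (%d, %d, %d)\n", block_1d_id, idx_N0, idx_H0, idx_W0);
        }
    }

Note also the InBlockLdsExtraW padding in the LDS descriptor's strides: widening each row of the relay tile by a few elements is a common trick to stagger LDS bank accesses during the transposed read-back.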
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  (+1 -0)

@@ -6,6 +6,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor/static_tensor.hpp"

 namespace ck {
include/ck/utility/span.hpp  (new file, mode 0 → 100644, +67)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstddef>
#include <array>
#include <type_traits>

namespace ck {

template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    constexpr span() : span(nullptr, size_type{0}) {}

    constexpr span(pointer first, size_type count) : ptr_(first), size_(count) {}

    constexpr span(pointer first, pointer last) : span(first, last - first) {}

    template <std::size_t N>
    constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    template <std::size_t N>
    constexpr span(std::array<value_type, N>& arr) noexcept : span(arr.data(), N)
    {
    }

    template <typename Container>
    constexpr span(const Container& container) : span(container.data(), container.size())
    {
    }

    constexpr iterator begin() const noexcept { return ptr_; }
    constexpr const_iterator cbegin() const noexcept { return begin(); }
    constexpr iterator end() const noexcept { return begin() + size(); }
    constexpr const_iterator cend() const noexcept { return end(); }

    constexpr reference front() const { return *begin(); }
    constexpr reference back() const { return *(end() - 1); }

    constexpr reference operator[](size_type idx) const { return *(begin() + idx); }
    constexpr pointer data() const noexcept { return ptr_; }

    constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};

} // namespace ck
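A small usage sketch for the new ck::span as a non-owning view over contiguous storage (the vector contents are arbitrary test data):

    #include <cstdio>
    #include <vector>

    #include "ck/utility/span.hpp"

    int main()
    {
        const std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};

        ck::span<const float> view{values}; // container constructor

        float sum = 0.0f;
        for(const float v : view) // begin()/end() make it range-for friendly
        {
            sum += v;
        }
        std::printf("size=%zu sum=%f front=%f back=%f\n",
                    view.size(), sum, view.front(), view.back());
    }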
include/ck/utility/transpose_vectors.hpp  (+17 -21)

@@ -34,17 +34,15 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
     y0 = vy0.template AsType<half2_t>()[I0];
     y1 = vy1.template AsType<half2_t>()[I0];
 #else
-    asm volatile("\n \
-            v_pack_b32_f16 %0, %1, %2 \n \
-            "
-                 : "=v"(y0)
-                 : "v"(x0), "v"(x1));
-
-    asm volatile("\n \
-            v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
-            "
-                 : "=v"(y1)
-                 : "v"(x0), "v"(x1));
+    constexpr int32_t m0 = 0x05040100;
+    constexpr int32_t m1 = 0x07060302;
+
+    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+    //        -- -- -- --    -- -- -- --        -  -  -  -
+    // index   7  6  5  4     3  2  1  0       33 77 44 88
+    // index is reversed because of little endianness (least significant bits first)
+    y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+    y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
 #endif
 }
@@ -106,16 +104,14 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0,
     //        -- -- -- --    -- -- -- --        -  -  -  -
     // index   7  6  5  4     3  2  1  0       33 77 44 88
     // index is reversed because of little endianness (least significant bits first)
     // clang-format off
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m0));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m0));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m3));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m3));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-    asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
+    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
+    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
+    z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+    z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
+    t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
+    t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
+    z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+    z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
     // clang-format on

     y0 = bit_cast<int8x4_t>(z0);
     y1 = bit_cast<int8x4_t>(z1);
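The rewrite replaces inline v_pack_b32_f16/v_perm_b32 assembly with the __builtin_amdgcn_perm intrinsic. A CPU model of the byte-select semantics the masks rely on, reproducing the worked example from the source comment (selector values 8-15, which select sign bytes or constants on hardware, are deliberately not modeled here):

    #include <cstdint>
    #include <cstdio>

    // Model of v_perm_b32: each selector byte picks one byte out of the 8-byte
    // little-endian concatenation {src_hi, src_lo}; indices 0..3 address src_lo,
    // indices 4..7 address src_hi. Argument order matches the kernel's
    // __builtin_amdgcn_perm(hi, lo, selector) calls.
    std::uint32_t perm_b32(std::uint32_t src_hi, std::uint32_t src_lo, std::uint32_t selector)
    {
        const std::uint64_t combined = (std::uint64_t{src_hi} << 32) | src_lo;

        std::uint32_t result = 0;
        for(int i = 0; i < 4; ++i)
        {
            const unsigned sel  = (selector >> (8 * i)) & 0xFF;        // one selector byte
            const unsigned byte = (combined >> (8 * (sel & 0x7))) & 0xFF; // selected byte
            result |= byte << (8 * i);
        }
        return result;
    }

    int main()
    {
        // Reproduces the example from the source comment:
        // v_perm_b32(0x11223344, 0x55667788, 0x05010400) -> 0x33774488
        std::printf("0x%08x\n", perm_b32(0x11223344u, 0x55667788u, 0x05010400u));
    }

With this model, mask m0 = 0x05040100 gathers the low fp16 halves of the two inputs and m1 = 0x07060302 the high halves, which is exactly the 2x2 half-precision transpose.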
library/include/ck/library/utility/check_err.hpp  (+25 -13)

@@ -15,6 +15,7 @@
 #include "ck/ck.hpp"
 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"
 #include "ck/utility/type.hpp"
 #include "ck/host_utility/io.hpp"
@@ -32,7 +33,7 @@ check_err(const std::vector<T>& out,
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }
@@ -50,7 +51,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
             }
             res = false;
@@ -58,7 +59,7 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }
@@ -73,7 +74,7 @@ check_err(const std::vector<T>& out,
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }
@@ -94,7 +95,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
@@ -102,22 +103,22 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }

 template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-          const std::vector<T>& ref,
+typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
+check_err(span<const T> out,
+          span<const T> ref,
           const std::string& msg = "Error: Incorrect results!",
           double rtol            = 1e-3,
          double atol            = 1e-3)
 {
     if(out.size() != ref.size())
     {
-        std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                   << std::endl;
         return false;
     }
@@ -137,7 +138,7 @@ check_err(const std::vector<T>& out,
             err_count++;
             if(err_count < 5)
             {
-                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
             }
             res = false;
@@ -145,11 +146,22 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
     }
     return res;
 }

+template <typename T>
+typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
+check_err(const std::vector<T>& out,
+          const std::vector<T>& ref,
+          const std::string& msg = "Error: Incorrect results!",
+          double rtol            = 1e-3,
+          double atol            = 1e-3)
+{
+    return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
+}
+
 template <typename T>
 std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
@@ -194,7 +206,7 @@ check_err(const std::vector<T>& out,
     }
     if(!res)
     {
-        std::cout << "max err: " << max_err << std::endl;
+        std::cerr << "max err: " << max_err << std::endl;
     }
     return res;
 }
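These hunks move diagnostics to std::cerr and retarget the half_t overload at span, so the element comparison itself is not visible here. As a reference point, a sketch of the usual mixed relative/absolute tolerance test that rtol/atol parameters like these feed into (the exact expression check_err uses is an assumption):

    #include <cmath>
    #include <cstdio>

    // Accept out if it is within atol plus rtol-scaled magnitude of ref.
    bool close_enough(double out, double ref, double rtol, double atol)
    {
        return std::abs(out - ref) <= atol + rtol * std::abs(ref);
    }

    int main()
    {
        std::printf("%d\n", close_enough(1.0005, 1.0, 1e-3, 1e-3)); // 1: within tolerance
        std::printf("%d\n", close_enough(1.01, 1.0, 1e-3, 1e-3));   // 0: outside tolerance
    }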
library/include/ck/library/utility/fill.hpp  (+12 -0)

@@ -5,7 +5,10 @@
 #include <algorithm>
 #include <cmath>
+#include <iterator>
 #include <random>
+#include <type_traits>
+#include <utility>

 #include "ck/utility/data_type.hpp"
@@ -25,6 +28,15 @@ struct FillUniformDistribution
         std::uniform_real_distribution<float> dis(a_, b_);
         std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
     }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range) -> std::void_t<decltype(
+        std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
+                                                std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
 };

 // Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
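The added operator() lets FillUniformDistribution accept a whole range instead of an iterator pair, SFINAE-constrained on the iterator-pair form being callable via std::begin/std::end. A standalone sketch of the same pattern (Filler is a stand-in, not the CK type):

    #include <algorithm>
    #include <cstdio>
    #include <iterator>
    #include <random>
    #include <type_traits>
    #include <utility>
    #include <vector>

    struct Filler
    {
        float a_ = 0.0f;
        float b_ = 1.0f;

        // Iterator-pair form, as the original functor provides.
        template <typename ForwardIter>
        void operator()(ForwardIter first, ForwardIter last) const
        {
            std::mt19937 gen{11939};
            std::uniform_real_distribution<float> dis(a_, b_);
            std::generate(first, last, [&] { return dis(gen); });
        }

        // Range form: enabled only when std::begin/std::end work on the argument.
        template <typename ForwardRange>
        auto operator()(ForwardRange&& range) const
            -> std::void_t<decltype(std::declval<const Filler&>()(
                std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range))))>
        {
            (*this)(std::begin(range), std::end(range));
        }
    };

    int main()
    {
        std::vector<float> v(4);
        Filler{}(v); // equivalent to Filler{}(v.begin(), v.end())
        for(float x : v)
            std::printf("%f\n", x);
    }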
library/include/ck/library/utility/host_tensor.hpp  (+45 -13)

@@ -3,15 +3,16 @@
 #pragma once

-#include <thread>
-#include <vector>
-#include <numeric>
 #include <algorithm>
-#include <utility>
 #include <cassert>
 #include <iostream>
+#include <numeric>
+#include <thread>
+#include <utility>
+#include <vector>

 #include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"

 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
@@ -235,6 +236,9 @@ auto make_ParallelTensorFunctor(F f, Xs... xs)
 template <typename T>
 struct Tensor
 {
+    using Descriptor = HostTensorDescriptor;
+    using Data       = std::vector<T>;
+
     template <typename X>
     Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
     {
@@ -251,7 +255,7 @@ struct Tensor
     {
     }

-    Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+    Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}

     template <typename OutT>
     Tensor<OutT> CopyAsType() const
@@ -278,9 +282,9 @@ struct Tensor
     {
     }

-    const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
+    decltype(auto) GetLengths() const { return mDesc.GetLengths(); }

-    const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
+    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }

     std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
@@ -288,6 +292,8 @@ struct Tensor
     std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }

+    std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
+
     void SetZero()
     {
         for(auto& v : mData)
@@ -425,14 +431,40 @@ struct Tensor
         return mData[mDesc.GetOffsetFromMultiIndex(idx)];
     }

-    typename std::vector<T>::iterator begin() { return mData.begin(); }
+    typename Data::iterator begin() { return mData.begin(); }

-    typename std::vector<T>::iterator end() { return mData.end(); }
+    typename Data::iterator end() { return mData.end(); }

+    typename Data::pointer data() { return mData.data(); }
+
-    typename std::vector<T>::const_iterator begin() const { return mData.begin(); }
+    typename Data::const_iterator begin() const { return mData.begin(); }

-    typename std::vector<T>::const_iterator end() const { return mData.end(); }
+    typename Data::const_iterator end() const { return mData.end(); }

+    typename Data::const_pointer data() const { return mData.data(); }
+
+    typename Data::size_type size() const { return mData.size(); }
+
+    template <typename U = T>
+    auto AsSpan() const
+    {
+        constexpr std::size_t FromSize = sizeof(T);
+        constexpr std::size_t ToSize   = sizeof(U);
+
+        using Element = std::add_const_t<std::remove_reference_t<U>>;
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+    }
+
+    template <typename U = T>
+    auto AsSpan()
+    {
+        constexpr std::size_t FromSize = sizeof(T);
+        constexpr std::size_t ToSize   = sizeof(U);
+
+        using Element = std::remove_reference_t<U>;
+        return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
+    }

-    HostTensorDescriptor mDesc;
-    std::vector<T> mData;
+    Descriptor mDesc;
+    Data mData;
 };
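AsSpan() reinterprets the tensor's storage as a span of another element type, scaling the element count by sizeof(T)/sizeof(U). A sketch of that size arithmetic, with std::vector standing in for Tensor:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include "ck/utility/span.hpp"

    int main()
    {
        std::vector<std::uint32_t> words{0x11223344u, 0x55667788u};

        constexpr std::size_t FromSize = sizeof(std::uint32_t); // 4
        constexpr std::size_t ToSize   = sizeof(std::uint8_t);  // 1

        // View 2 x 4-byte words as 8 bytes, as AsSpan<std::uint8_t>() would.
        ck::span<const std::uint8_t> bytes{
            reinterpret_cast<const std::uint8_t*>(words.data()),
            words.size() * FromSize / ToSize};

        // 8 bytes; first byte is 0x44 on a little-endian host.
        std::printf("%zu bytes, first byte 0x%02x\n", bytes.size(), unsigned(bytes[0]));
    }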
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  (+19 -0)

@@ -55,6 +55,22 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
         // clang-format on
         >;

+using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = std::tuple<
+    // clang-format off
+    //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+    //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
+    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
+    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
+    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
+    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
+    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
+    DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
+    // clang-format on
+    >;
+
 void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
                                                              Col,
@@ -73,6 +89,9 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
     add_device_operation_instances(
         instances,
         device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
+    add_device_operation_instances(
+        instances,
+        device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{});
 }

 } // namespace instance
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp  (+13 -0)

@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
     this->Run();
 }

+TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
+{
+    this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 16},
+                                                   {256, 64, 160, 64, 16},
+                                                   {1024, 1024, 80, 80, 16},
+                                                   {1024, 64, 80, 64, 16},
+                                                   {4096, 4096, 40, 40, 16},
+                                                   {4096, 64, 40, 64, 16}};
+    this->bench_   = true;
+    this->verify_  = false;
+    this->Run();
+}
+
 using ck::tensor_operation::device::GemmSpecialization;

 // TODO: enable KPadding tests when it is implemented
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp  (+13 -8)

@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
     using B1Layout = std::tuple_element_t<6, Tuple>;
     using CLayout  = std::tuple_element_t<7, Tuple>;

-    std::vector<std::vector<int>> lengths_ = {
-        {256, 256, 64, 64, 4},
-        {256, 256, 128, 128, 4},
-        {512, 512, 64, 64, 2},
-        {512, 512, 128, 128, 2},
-        {1024, 1024, 64, 64, 1},
-        {1024, 1024, 128, 128, 1},
-    };
+    std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
+                                              {256, 256, 128, 128, 4},
+                                              {512, 512, 64, 64, 2},
+                                              {512, 512, 128, 128, 2},
+                                              {1024, 1024, 64, 64, 1},
+                                              {1024, 1024, 128, 128, 1},
+                                              {256, 256, 160, 160, 4},
+                                              {256, 64, 160, 64, 4},
+                                              {1024, 1024, 80, 80, 2},
+                                              {1024, 64, 80, 64, 2},
+                                              {4096, 4096, 40, 40, 1},
+                                              {4096, 64, 40, 64, 1}};

     bool bench_  = false;
     bool verify_ = true;