gaoqiong / composable_kernel · Commit a085b740

Authored Jul 27, 2022 by ltqin
Parent: 652728bc

    complete first version
Showing 2 changed files with 35 additions and 16 deletions:

  include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp     +32 −15
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp  +3  −1
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
```diff
@@ -23,10 +23,8 @@ struct BlockwiseSoftmax_V1
 {
     static_assert(MRepeat == 1, "Now MRepeat must equal 1");
-    static __shared__ AccDataType p_lex[MPerBlock];
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr index_t MThreadSliceSize = 1;
     static constexpr index_t WaveSize         = 64;
```
```diff
@@ -36,8 +34,9 @@ struct BlockwiseSoftmax_V1
 {
     __host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;

     template <typename TopIdx>
-    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    __host__ __device__ static constexpr auto CalculateBottomIndex(const TopIdx& idx_top)
     {
         const auto index = idx_top[I0];
         const auto m     = (index / WaveSize) * MPerXDL + index % MPerXDL;
         const auto k     = (index % WaveSize) / MPerXDL;
```
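The adapter maps a flat thread id within the block onto an (m, k) coordinate: each 64-wide wave owns a band of MPerXDL rows, lanes within the wave cycle through those rows, and each group of MPerXDL lanes advances one step along k. A minimal host-side sketch of the same arithmetic, assuming WaveSize = 64 as in the struct above and an illustrative MPerXDL = 32 (not a value fixed by this diff):

```cpp
#include <cstdio>
#include <initializer_list>

// Host-side sketch of BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex.
// MPerXDL = 32 is an assumed example value, not one set by this commit.
constexpr int WaveSize = 64;
constexpr int MPerXDL  = 32;

struct MK { int m; int k; };

constexpr MK calculate_bottom_index(int index)
{
    // Wave w owns rows [w * MPerXDL, (w + 1) * MPerXDL); within a wave,
    // lane % MPerXDL picks the row and lane / MPerXDL picks the k slot.
    return MK{(index / WaveSize) * MPerXDL + index % MPerXDL,
              (index % WaveSize) / MPerXDL};
}

int main()
{
    for(int tid : {0, 1, 31, 32, 63, 64, 127})
    {
        const MK mk = calculate_bottom_index(tid);
        std::printf("thread %3d -> (m = %2d, k = %d)\n", tid, mk.m, mk.k);
    }
    return 0;
}
```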
```diff
@@ -45,6 +44,9 @@ struct BlockwiseSoftmax_V1
     }
 };

+    static constexpr auto softmax_buf_desc_m_k =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerBlock>{}, Number<2>{}));
+
     constexpr static auto in_thread_desc = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<RegSizePerXdlops>{}));
```
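The new `softmax_buf_desc_m_k` describes a packed (MPerBlock, 2) array, giving each row two adjacent `AccDataType` slots in LDS; the hunks below write the row sum to column 0 and the row max to column 1. For a packed layout the offset reduces to `m * 2 + col`. A small sketch of that arithmetic (the helper name here is illustrative, not a composable_kernel API):

```cpp
#include <cassert>

// Stand-in for softmax_buf_desc_m_k.CalculateOffset() on a packed
// (MPerBlock, 2) layout: stride 2 along m, stride 1 along the 2-wide axis.
constexpr int calc_offset(int m, int col) { return m * 2 + col; }

int main()
{
    // Row 5's sum lives at element 10, its max at element 11.
    assert(calc_offset(5, 0) == 10); // column 0: row sum
    assert(calc_offset(5, 1) == 11); // column 1: row max
    return 0;
}
```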
```diff
@@ -89,13 +91,22 @@ struct BlockwiseSoftmax_V1
         false, // ignored
         detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

     template <typename CThreadBuffer>
-    __host__ __device__ static void Run(CThreadBuffer& in_thread_buf, void* __restrict__ p_shared)
+    __host__ __device__ static void Run(CThreadBuffer& in_thread_buf,
+                                        void* __restrict__ p_shared,
+                                        void* __restrict__ p_softmax)
     {
         // printf("in_thread_desc: {%d, %d, %d}", in_thread_desc.GetLength(I0).value,
         // in_thread_desc.GetLength(I1).value, in_thread_desc.GetLength(I2).value);
         auto reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<AccDataType*>(p_shared), BlockSize);
+        auto softmax_lds_buffer = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<AccDataType*>(p_softmax), MPerBlock * 2);

         // static auto lds_buffer_m_k = GetSpaceForPreMax();
         const index_t thread_local_id = get_thread_local_1d_id();
         const auto thread_cluster_idx =
             BlockToMKMap_M0_K_M1Adapt::CalculateBottomIndex(make_multi_index(thread_local_id));
         const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
         // const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

         //
         // find max value
         //
```
```diff
@@ -118,9 +129,11 @@ struct BlockwiseSoftmax_V1
         BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I0));
         block_sync_lds();

+        // save max value
+        softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
+            make_tuple(thread_m_cluster_id, 1))) = max_value_buf(I0);
+
         // {const index_t thread_local_id = get_thread_local_1d_id();
         // printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
         printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);

         //
         // softmax
```
```diff
@@ -150,16 +163,20 @@ struct BlockwiseSoftmax_V1
         BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I0));
         block_sync_lds();

+        // save sum
+        softmax_lds_buffer(softmax_buf_desc_m_k.CalculateOffset(
+            make_tuple(thread_m_cluster_id, 0))) = accu_value_buf(I0);
+
         // change elements
-        static_for<0, NRepeat, 1>{}([&](auto n) {
-            constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
-            auto& xdlops_out = in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
-            static_for<0, in_thread_desc.GetLength(I2), 1>{}([&](auto iK) {
-                xdlops_out.template AsType<float>()(iK) =
-                    xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
-            });
-        });
+        /*
+        static_for<0, NRepeat, 1>{}([&](auto n) {
+            constexpr index_t in_offset = in_thread_desc.CalculateOffset(make_tuple(0, n, 0));
+            auto& xdlops_out =
+                in_thread_buf.GetVectorTypeReference(Number<in_offset>{});
+            static_for<0, RegSizePerXdlops, 1>{}([&](auto iK) {
+                xdlops_out.template AsType<float>()(iK) =
+                    xdlops_out.template AsType<float>()[iK] / accu_value_buf(I0);
+            });
+        });
+        */
     }
 };
 } // namespace ck
```
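Taken together, the hunks stage the classic numerically stable softmax: a blockwise max-reduce per row, exponentiation of the max-shifted values, a blockwise sum-reduce, and a final divide (the divide loop is commented out in this commit, with the per-row max and sum parked in LDS instead). A single-threaded sketch of the same sequence for one row, assuming plain float data:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Single-threaded sketch of the per-row softmax the block performs:
// max-reduce, subtract-and-exponentiate, sum-reduce, then divide.
// The kernel spreads the same steps across a thread block and keeps
// the row max/sum in LDS between them.
std::vector<float> row_softmax(std::vector<float> row)
{
    const float mx = *std::max_element(row.begin(), row.end()); // BlockwiseMaxReduce
    float sum = 0.f;
    for(float& x : row)
    {
        x = std::exp(x - mx); // shift by max for numerical stability
        sum += x;             // BlockwiseSumReduce
    }
    for(float& x : row)
        x /= sum; // the (commented-out) "change elements" divide
    return row;
}

int main()
{
    for(float p : row_softmax({1.f, 2.f, 3.f}))
        std::printf("%f ", p); // 0.090031 0.244728 0.665241
    return 0;
}
```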
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
```diff
@@ -474,6 +474,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                               num_k_block_main_loop);

     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
+        __shared__ AccDataType p_lex[MPerBlock * 2];

         using BlockwiseSoftmax = BlockwiseSoftmax_V1<BlockSize,
                                                      FloatAcc,
                                                      MPerBlock,
```
```diff
@@ -482,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                                                      blockwise_gemm.GetRegSizePerXdlops(),
                                                      MXdlPerWave,
                                                      NXdlPerWave>;
-        BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer);
+        BlockwiseSoftmax::Run(c_thread_buf, p_reduce_work_buffer, p_lex);
     }

     // output: register to global memory
```
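The caller-side change mirrors the new Run signature: the gridwise kernel now owns the softmax scratch space, sized at two AccDataType values per row (sum and max), and passes it alongside the reduce workspace. As a rough sizing check, assuming example tile sizes and FloatAcc = float (neither is fixed by this diff):

```cpp
#include <cstddef>
#include <cstdio>

// Rough LDS budget for the two scratch buffers the gridwise kernel now
// declares, with assumed MPerBlock = 128, BlockSize = 256, FloatAcc = float.
int main()
{
    constexpr int MPerBlock = 128;
    constexpr int BlockSize = 256;
    constexpr std::size_t lex_bytes    = MPerBlock * 2 * sizeof(float); // p_lex: sum + max per row
    constexpr std::size_t reduce_bytes = BlockSize * sizeof(float);     // p_reduce_work_buffer
    std::printf("p_lex: %zu bytes, reduce workspace: %zu bytes\n",
                lex_bytes, reduce_bytes); // 1024 and 1024 bytes
    return 0;
}
```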