Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8dd7156d
Commit
8dd7156d
authored
Jul 25, 2023
by
ltqin
Browse files
Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask
parents
d5f629e7
b5a3ea2d
Changes
533
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
48 additions
and
42 deletions
+48
-42
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
...ration/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
+14
-12
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+2
-2
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
+10
-5
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
+1
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
+1
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
+1
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
...ation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
+1
-1
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
...n/gpu/device/convolution_backward_data_specialization.hpp
+2
-3
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp
...gpu/device/convolution_backward_weight_specialization.hpp
+1
-1
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
...eration/gpu/device/convolution_forward_specialization.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp
...tion/gpu/device/device_batched_contraction_multiple_d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp
.../tensor_operation/gpu/device/device_batched_gemm_gemm.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
...nsor_operation/gpu/device/device_batched_gemm_multi_d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
...device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
...operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
+5
-5
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp
...tensor_operation/gpu/device/device_batchnorm_backward.hpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
533 of 533+
files are displayed.
Plain diff
Email patch
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -35,10 +35,11 @@ struct BlockwiseWelford
...
@@ -35,10 +35,11 @@ struct BlockwiseWelford
static
constexpr
auto
thread_cluster_desc
=
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
template
<
typename
CountDataType
>
__device__
static
inline
void
__device__
static
inline
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
Merge
(
T
&
mean_a
,
T
&
var_a
,
CountDataType
&
count_a
,
T
mean_b
,
T
var_b
,
CountDataType
count_b
)
{
{
int
count
=
count_a
+
count_b
;
CountDataType
count
=
count_a
+
count_b
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
delta
=
mean_b
-
mean_a
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
mean_a
+=
delta
*
count_b_over_count
;
...
@@ -46,11 +47,12 @@ struct BlockwiseWelford
...
@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a
=
count
;
count_a
=
count
;
}
}
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
int
&
count
)
template
<
typename
CountDataType
>
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
CountDataType
&
count
)
{
{
__shared__
T
mean_block_buf
[
BlockSize
];
__shared__
T
mean_block_buf
[
BlockSize
];
__shared__
T
var_block_buf
[
BlockSize
];
__shared__
T
var_block_buf
[
BlockSize
];
__shared__
int
count_block_buf
[
BlockSize
];
__shared__
CountDataType
count_block_buf
[
BlockSize
];
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
...
@@ -76,13 +78,13 @@ struct BlockwiseWelford
...
@@ -76,13 +78,13 @@ struct BlockwiseWelford
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
make_tuple
(
0
,
indOffset
));
T
mean1
=
mean_block_buf
[
offset1
];
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
int
count1
=
count_block_buf
[
offset1
];
CountDataType
count1
=
count_block_buf
[
offset1
];
T
mean2
=
mean_block_buf
[
offset2
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
int
count2
=
count_block_buf
[
offset2
];
CountDataType
count2
=
count_block_buf
[
offset2
];
Merge
(
mean1
,
var1
,
count1
,
mean2
,
var2
,
count2
);
Merge
(
mean1
,
var1
,
count1
,
mean2
,
var2
,
count2
);
...
...
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace
ck
{
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -11,10 +11,15 @@
...
@@ -11,10 +11,15 @@
namespace
ck
{
namespace
ck
{
// this version does following things to avoid scratch memory issue
/**
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
* @brief Blockwise data transfer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
*
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template
<
typename
ThreadGroup
,
template
<
typename
ThreadGroup
,
typename
SrcElementwiseOperation
,
typename
SrcElementwiseOperation
,
typename
DstElementwiseOperation
,
typename
DstElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
...
@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch
(
s
)
switch
(
s
)
{
{
case
ConvolutionBackwardDataSpecialization
::
Default
:
return
"Default"
;
case
ConvolutionBackwardDataSpecialization
::
Default
:
return
"Default"
;
case
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
:
case
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
:
return
"Filter1x1Stride1Pad0"
;
return
"FFilter1x1Stride1Pad0"
;
default:
return
"Unrecognized specialization!"
;
default:
return
"Unrecognized specialization!"
;
}
}
}
}
...
...
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -27,9 +27,9 @@ template <index_t NumDimG,
...
@@ -27,9 +27,9 @@ template <index_t NumDimG,
typename
Acc1BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
Acc0
ElementwiseOperation
,
typename
C0DE
ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
1DE
ElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
MaskingSpecialization
MaskingSpec
>
struct
DeviceBatchedGemmSoftmaxGemmPermute
:
public
BaseOperator
struct
DeviceBatchedGemmSoftmaxGemmPermute
:
public
BaseOperator
{
{
...
@@ -59,9 +59,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
...
@@ -59,9 +59,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
B0ElementwiseOperation
b0_element_op
,
Acc0
ElementwiseOperation
ac
c0_element_op
,
C0DE
ElementwiseOperation
c0
de
_element_op
,
B1ElementwiseOperation
b1_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
C
1DE
ElementwiseOperation
c
1de
_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp
View file @
8dd7156d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment