Commit 8dd7156d authored by ltqin's avatar ltqin
Browse files

Merge branch 'mha-train-develop' into attn-train-develop-qloop-mask

parents d5f629e7 b5a3ea2d
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp" #include "ck/utility/get_shift.hpp"
namespace ck { namespace ck {
...@@ -35,10 +35,11 @@ struct BlockwiseWelford ...@@ -35,10 +35,11 @@ struct BlockwiseWelford
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
template <typename CountDataType>
__device__ static inline void __device__ static inline void
Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b)
{ {
int count = count_a + count_b; CountDataType count = count_a + count_b;
T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count; T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
T delta = mean_b - mean_a; T delta = mean_b - mean_a;
mean_a += delta * count_b_over_count; mean_a += delta * count_b_over_count;
...@@ -46,11 +47,12 @@ struct BlockwiseWelford ...@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a = count; count_a = count;
} }
__device__ static void Run(T& mean_value, T& var_value, int& count) template <typename CountDataType>
__device__ static void Run(T& mean_value, T& var_value, CountDataType& count)
{ {
__shared__ T mean_block_buf[BlockSize]; __shared__ T mean_block_buf[BlockSize];
__shared__ T var_block_buf[BlockSize]; __shared__ T var_block_buf[BlockSize];
__shared__ int count_block_buf[BlockSize]; __shared__ CountDataType count_block_buf[BlockSize];
constexpr auto cluster_len_shift = get_shift<BufferLength_K>(); constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
...@@ -76,13 +78,13 @@ struct BlockwiseWelford ...@@ -76,13 +78,13 @@ struct BlockwiseWelford
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset)); make_tuple(0, indOffset));
T mean1 = mean_block_buf[offset1]; T mean1 = mean_block_buf[offset1];
T var1 = var_block_buf[offset1]; T var1 = var_block_buf[offset1];
int count1 = count_block_buf[offset1]; CountDataType count1 = count_block_buf[offset1];
T mean2 = mean_block_buf[offset2]; T mean2 = mean_block_buf[offset2];
T var2 = var_block_buf[offset2]; T var2 = var_block_buf[offset2];
int count2 = count_block_buf[offset2]; CountDataType count2 = count_block_buf[offset2];
Merge(mean1, var1, count1, mean2, var2, count2); Merge(mean1, var1, count1, mean2, var2, count2);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp" #include "ck/utility/get_shift.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp" #include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -11,10 +11,15 @@ ...@@ -11,10 +11,15 @@
namespace ck { namespace ck {
// this version does following things to avoid scratch memory issue /**
// 1. Use StaticallyIndexedArray instead of C array for thread buffer * @brief Blockwise data transfer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor *
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate * This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template <typename ThreadGroup, template <typename ThreadGroup,
typename SrcElementwiseOperation, typename SrcElementwiseOperation,
typename DstElementwiseOperation, typename DstElementwiseOperation,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat ...@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
switch(s) switch(s)
{ {
case ConvolutionBackwardDataSpecialization::Default: return "Default"; case ConvolutionBackwardDataSpecialization::Default: return "Default";
case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
return "FFilter1x1Stride1Pad0";
default: return "Unrecognized specialization!"; default: return "Unrecognized specialization!";
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -27,9 +27,9 @@ template <index_t NumDimG, ...@@ -27,9 +27,9 @@ template <index_t NumDimG,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation, typename C0DEElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{ {
...@@ -59,9 +59,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator ...@@ -59,9 +59,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, C0DEElementwiseOperation c0de_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0; C1DEElementwiseOperation c1de_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment