"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "d20739121e8c66632ad917987a983238299ebc56"
Commit 35e5c532 authored by aska-0096's avatar aska-0096
Browse files

clang format

parent b010b095
......@@ -252,25 +252,27 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
static auto MakeAGridDescriptor_AK0_M_AK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
static auto MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
static auto MakeBGridDescriptor_BK0_N_BK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{});
}
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec)
static auto MakeB1GridDescriptor_BK0_N_BK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
......@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
......@@ -836,7 +842,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
C1DEElementwiseOperation c1de_element_op)
{
constexpr auto dimension = NumDimG + NumDimM + NumDimN;
std::array<index_t, dimension> a_gs_ms_ks_lengths_{};
std::array<index_t, dimension> a_gs_ms_ks_strides_{};
std::array<index_t, dimension> b_gs_ns_ks_lengths_{};
......@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin());
std::copy(a_gs_ms_ks_lengths.begin(),
a_gs_ms_ks_lengths.begin() + dimension,
a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension,
c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension,
c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return Argument{p_a,
......@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin());
std::copy(a_gs_ms_ks_lengths.begin(),
a_gs_ms_ks_lengths.begin() + dimension,
a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension,
c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension,
c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment