"vscode:/vscode.git/clone" did not exist on "1a8a46628eeb54a0ea23990ce74a96790d0625af"
Commit 35e5c532 authored by aska-0096's avatar aska-0096
Browse files

clang format

parent b010b095
......@@ -252,7 +252,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
static auto MakeAGridDescriptor_AK0_M_AK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
......@@ -260,7 +261,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<AK1>{});
}
static auto MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
static auto MakeBGridDescriptor_BK0_N_BK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
......@@ -268,8 +270,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<BK1>{});
}
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
static auto MakeB1GridDescriptor_BK0_N_BK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
......@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
......@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin());
std::copy(a_gs_ms_ks_lengths.begin(),
a_gs_ms_ks_lengths.begin() + dimension,
a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension,
c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension,
c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return Argument{p_a,
......@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin());
std::copy(a_gs_ms_ks_lengths.begin(),
a_gs_ms_ks_lengths.begin() + dimension,
a_gs_ms_ks_lengths_.begin());
std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension,
b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension,
c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension,
c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment