"driver/include/host_tensor.hpp" did not exist on "fc98757acd68219eebecb16b15ac472172f6dd55"
Commit 96663815 authored by turneram

Formatting

parent 0ccee797
@@ -20,7 +20,7 @@ namespace op {
 struct transposectx
 {
-    int head_size = 64;
+    int head_size    = 64;
     bool reversed_bs = false;

    template <class Self, class F>
@@ -46,9 +46,9 @@ struct transposectx
        // Output: BxSxNxH
        argument result{output_shape};
        visit_all(result, args.front())([&](auto output, const auto input) {
            par_for(output_shape.elements(), [&](auto i) {
                // TODO: calculate in_offset and out_offset
                output[i] = input[i];
            });
        });
...
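The compute body above still copies output[i] = input[i] with the offset TODO unresolved, so it is only correct when the input and output layouts happen to coincide. A minimal sketch of what the offset math could look like, mirroring the non-reversed_bs indexing of the GPU kernel further down (an illustration only, not the committed implementation; the reversed_bs case is omitted):

    // Hypothetical sketch, not the committed code. Mirrors the GPU
    // kernel's default (reversed_bs == false) BxNxSxH -> BxSxNxH mapping.
    visit_all(result, args.front())([&](auto output, const auto input) {
        const auto& lens = output_shape.lens(); // {B, S, N, H}
        const std::size_t S = lens[1], N = lens[2], H = lens[3];
        par_for(output_shape.elements(), [&](auto i) {
            // Decompose the flat BxSxNxH output index i into (b, s, n, h).
            const std::size_t h = i % H;
            const std::size_t n = (i / H) % N;
            const std::size_t s = (i / (H * N)) % S;
            const std::size_t b = i / (H * N * S);
            // Read the same element from the BxNxSxH input layout.
            const std::size_t in_offset = ((b * N + n) * S + s) * H + h;
            output[i] = input[in_offset];
        });
    });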
@@ -20,7 +20,7 @@ namespace op {
 struct transposeqkv
 {
-    int head_size = 64;
+    int head_size    = 64;
     bool reversed_bs = false;

    template <class Self, class F>
@@ -47,9 +47,9 @@ struct transposeqkv
        // K is the number of identical matrices
        argument result{output_shape};
        visit_all(result, args.front())([&](auto output, const auto input) {
            par_for(output_shape.elements(), [&](auto i) {
                // TODO: calculate in_offset and out_offset
                output[i] = input[i];
            });
        });
...
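The transposeqkv reference body carries the same TODO. A corresponding sketch, again assuming the default reversed_bs == false path and the BxSxKxNxH input layout implied by the GPU kernel below (hypothetical, not the committed code):

    // Hypothetical sketch, not the committed code. Input read as
    // BxSxKxNxH, output written as KxBxNxSxH.
    visit_all(result, args.front())([&](auto output, const auto input) {
        const auto& lens = output_shape.lens(); // {K, B, N, S, H}
        const std::size_t K = lens[0];
        const std::size_t B = lens[1], N = lens[2], S = lens[3], H = lens[4];
        par_for(output_shape.elements(), [&](auto i) {
            // Decompose the flat KxBxNxSxH output index i into (k, b, n, s, h).
            const std::size_t h = i % H;
            const std::size_t s = (i / H) % S;
            const std::size_t n = (i / (H * S)) % N;
            const std::size_t b = (i / (H * S * N)) % B;
            const std::size_t k = i / (H * S * N * B);
            // Read the same element from the BxSxKxNxH input layout.
            const std::size_t in_offset = (((b * S + s) * K + k) * N + n) * H + h;
            output[i] = input[in_offset];
        });
    });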
@@ -162,8 +162,8 @@ struct parse_attention : op_parser<parse_attention>
        auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
        // scratch3 is BxNxSxH, transpose to output BxSxNxH
-       gemm4 = info.add_instruction(
-           migraphx::make_op("transposectx", {{"head_size", head_size}}), gemm4);
+       gemm4 = info.add_instruction(migraphx::make_op("transposectx", {{"head_size", head_size}}),
+                                    gemm4);
        gemm4 = info.add_instruction(
            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
            info.make_contiguous(gemm4));
...
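For context, the shape flow through this dot -> transposectx -> reshape sequence, using the dimensions from the verify tests below (B=1, N=12, S=128, H=64; the softmax and v_t shapes are the usual attention shapes and are assumed here for illustration):

    // dot:          softmax (1x12x128x128) x v_t (1x12x128x64) -> 1x12x128x64   (BxNxSxH)
    // transposectx: 1x12x128x64 -> 1x128x12x64                                  (BxSxNxH)
    // reshape:      1x128x12x64 -> {1, 128, 12 * 64} = {1, 128, 768}            (BxSx(N*H))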
@@ -89,7 +89,6 @@ __global__ void transposeqkv_kernel(void* input_p, void* output_p)
} // namespace migraphx
)__migraphx__";
-
struct transposeqkv_compiler : compiler<transposeqkv_compiler>
{
    std::vector<std::string> names() const { return {"transposeqkv"}; }
@@ -121,7 +120,6 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
    }
};
-
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -25,32 +25,35 @@ __device__ void transposectx(const T& input_t, const U& output_t, Settings st)
    // Input: BxNxSxH
    // Output: BxSxNxH
    auto head_size = st.head_size;
    auto reversed_bs = st.reversed_bs;
    int n = threadIdx.y;
    int s = blockIdx.x;
    int b = blockIdx.y;
    int num_heads = blockDim.y;
    int sequence_length = gridDim.x;
    const int NH = num_heads * head_size;
    const int NHS = NH * sequence_length;
    const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
    int out_offset = 0;
-   if (reversed_bs) {
+   if(reversed_bs)
+   {
        const int batch_size = gridDim.y;
        const int BNH = NH * batch_size;
        out_offset = n * head_size + b * NH + s * BNH;
-   } else {
+   }
+   else
+   {
        out_offset = n * head_size + s * NH + b * NHS;
    }
    const int i = threadIdx.x;
-   if (i < head_size) {
+   if(i < head_size)
+   {
        output_t[out_offset + i] = input_t[in_offset + i];
    }
}
...
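As a sanity check on the kernel's indexing, with the dimensions from the transposectx verify test below (B=1, N=12, S=128, H=64, assumed here) and the non-reversed_bs path:

    // One element, (b=0, n=2, s=5, h=7):
    //   NH  = 12 * 64   = 768
    //   NHS = 768 * 128 = 98304
    //   in_offset  = 5*64 + 2*128*64 + 0*98304 = 16704  // start of row (0,2,5,.) in BxNxSxH
    //   out_offset = 2*64 + 5*768    + 0*98304 = 3968   // start of row (0,5,2,.) in BxSxNxH
    // Thread h=7 copies input[16704 + 7] to output[3968 + 7]: flat index 16711
    // of the BxNxSxH input lands at flat index 3975 of the BxSxNxH output,
    // which is exactly the transpose of the N and S axes.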
@@ -26,33 +26,37 @@ __device__ void transposeqkv(const T& input_t, const U& output_t, Settings st)
    // Output: KxBxNxSxH
    // K is the number of identical matrices
    auto H = st.head_size;
    auto reversed_bs = st.reversed_bs;
    int n = threadIdx.y;
    int s = blockIdx.x;
    int b = blockIdx.y;
    int m = blockIdx.z; // matrix id
    const int num_heads = blockDim.y;
    const int sequence_length = gridDim.x;
    const int batch_size = gridDim.y;
    const int chunk_num = gridDim.z;
    const int NH = num_heads * H;
    const int NHS = NH * sequence_length;
    int in_offset = 0;
-   if (reversed_bs) {
+   if(reversed_bs)
+   {
        const int BNH = NH * batch_size;
        in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
-   } else {
+   }
+   else
+   {
        in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
    }
    const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
    const int i = threadIdx.x;
-   if (i < H) {
+   if(i < H)
+   {
        output_t[out_offset + i] = input_t[in_offset + i];
    }
}
...
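A similar check for transposeqkv. The in_offset formula implies a BxSxKxNxH input, which matches the {1, 1, 3, 12, 64} parameter in the verify test below; here hypothetical dimensions B=1, S=2, K=3 (chunk_num), N=12, H=64 are used so that the batch and sequence strides stay distinguishable:

    // Element (m=2, b=0, n=1, s=1, h=0), non-reversed_bs path:
    //   NH  = 12 * 64 = 768
    //   NHS = 768 * 2 = 1536
    //   in_offset  = 1*64 + (2 + 1*3)*768 + 0*1536*3     = 3904  // (b=0, s=1, k=2, n=1) in BxSxKxNxH
    //   out_offset = 1*64 + 1*2*64 + 0*1536 + 2*1536*1   = 3264  // (k=2, b=0, n=1, s=1) in KxBxNxSxH
    // So the kernel de-interleaves the fused QKV tensor: chunk m of each
    // (batch, sequence) slice is gathered into the m-th KxBxNxSxH matrix.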
@@ -10,7 +10,8 @@ struct test_transposectx : verify_program<test_transposectx>
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
-       auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
+       auto x =
+           mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
        mm->add_instruction(migraphx::make_op("transposectx", {{"head_size", 64}}), x);
        p.debug_print();
        return p;
...
@@ -9,7 +9,8 @@ struct test_transposeqkv : verify_program<test_transposeqkv>
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
-       auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 12, 64}});
+       auto x =
+           mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 12, 64}});
        mm->add_instruction(migraphx::make_op("transposeqkv", {{"head_size", 64}}), x);
        p.debug_print();
        return p;
...