Commit 29fa2666 authored by Paul Fultz II, committed by GitHub

Add gpu driver and improvements to pointwise codegen (#851)



* Add method to compile pointwise

* Formatting

* Add lambda

* Add semicolon

* Rename variable

* Add driver to run jit kernels

* Formatting

* Add context

* Formatting

* Make separate driver folder

* Add more general gpu driver

* Formatting

* Print out wall time

* Formatting

* Run multiple times and skip first run

* Formatting

* Separate time_op

* Run an op for comparison

* Formatting

* Add debug asserts

* Formatting

* Change parameter name

* Formatting

* Fix argument order

* Formatting

* Add preloading

* Formatting

* Allow a different data type

* Formatting

* Pipeline transformations

* Formatting

* Add vectorization

* Formatting

* Reduce dims

* Formatting

* Compile with launch params as constant

* Formatting

* Make sure buffer can be vectorized

* Formatting

* Enable vectorization and preloading

* Formatting

* Add print header

* Formatting

* Avoid allocating too much LDS

* Formatting

* Add some vec functions to a separate header

* Formatting

* Add stride loops

* Formatting

* Improve the transform pipeline

* Formatting

* Add const

* Fix shape check

* Formatting

* Just check stride axis is zero

* Remove extra find_vector_axis overload

* Simplify some more functions

* Formatting

* Remove some more extra functions

* Formatting

* Simplify more decltypes

* Add another const

* Fix test

* Get buffer pointer different for older compilers

Co-authored-by: Shucai Xiao <shucai@gmail.com>
Co-authored-by: Chris Austen <causten@users.noreply.github.com>
parent 30966f6b
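
The commit message mentions a `time_op` helper that runs an operation multiple times and skips the first run. The commit's actual implementation is not shown on this page, so the following is only a minimal host-side sketch of that pattern. The name `time_op_ms`, its signature, and the use of `std::chrono::steady_clock` are assumptions for illustration. A real GPU driver would also need to synchronize with the device before reading the clock.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>

// Hypothetical helper, not the time_op from this commit.
// Runs `op` once untimed as a warm-up, then returns the average
// wall time in milliseconds over `runs` timed calls.
inline double time_op_ms(const std::function<void()>& op, std::size_t runs)
{
    assert(runs > 0);
    using clock = std::chrono::steady_clock;
    // The first run is skipped because it may include one-off costs
    // (for example lazy initialization or kernel loading).
    op();
    const auto start = clock::now();
    for(std::size_t i = 0; i < runs; i++)
        op();
    const auto stop = clock::now();
    const std::chrono::duration<double, std::milli> elapsed = stop - start;
    return elapsed.count() / static_cast<double>(runs);
}
```

With `runs = 5` the operation is invoked six times in total: one warm-up call plus five timed calls.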
@@ -54,7 +54,7 @@ using namespace migraphx;
 extern "C" {
 __global__ void kernel(void* x, void* y)
 {
-    make_tensors(x, y)([](auto xt, auto yt) __device__ {
+    make_tensors()(x, y)([](auto xt, auto yt) __device__ {
         auto idx = make_index();
         const auto stride = idx.nglobal();
         for(index_int i = idx.global; i < xt.get_shape().elements(); i += stride)
......
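
The loop in the hunk above is a grid-stride loop: each thread starts at its own global index and advances by the total number of threads, so every element is visited exactly once regardless of how many threads are launched. The sketch below emulates that on the host in plain C++ to make the coverage property easy to check. `fake_index`, `stride_loop`, and `count_visits` are hypothetical names introduced here; they are not part of MIGraphX.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the index object used in the kernel:
// `global` is this thread's global id, `nglobal()` the total thread count.
struct fake_index
{
    std::size_t global;
    std::size_t n;
    std::size_t nglobal() const { return n; }
};

// Work done by one "thread": visit global, global + stride, global + 2 * stride, ...
template <class F>
void stride_loop(fake_index idx, std::size_t elements, F f)
{
    const auto stride = idx.nglobal();
    for(std::size_t i = idx.global; i < elements; i += stride)
        f(i);
}

// Run every emulated thread sequentially and count how often each element is visited.
inline std::vector<int> count_visits(std::size_t nthreads, std::size_t elements)
{
    assert(nthreads > 0);
    std::vector<int> visits(elements, 0);
    for(std::size_t t = 0; t < nthreads; t++)
        stride_loop(fake_index{t, nthreads}, elements, [&](std::size_t i) { visits[i]++; });
    return visits;
}
```

Calling `count_visits(4, 10)` or `count_visits(16, 3)` yields a vector in which every entry is 1, which covers both the case of fewer threads than elements and the case of more.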