"...git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "da1e83cc341165d5c9e35ce8d54647780e4bf4e1"
Commit fbcb4570 authored by Paul's avatar Paul
Browse files

Latest

parent 60c6738a
......@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
{
if(z == 0)
return 0;
return (2.0 / z) - 1.0;
const auto max = 32768;
const double range = max / 2;
double result = (z % max) / range;
result -= 1;
return result;
}
template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})>
......@@ -54,11 +58,28 @@ struct xorshf96_generator
}
};
template <class T>
// Pseudo-random generator from Marsaglia's xorshift family. The shift
// triple (12, 25, 27) with a final multiply by 0x2545F4914F6CDD1D matches
// the xorshift64* variant, which requires a 64-bit, never-zero state.
// NOTE(review): `unsigned long` is only 32 bits on some targets (e.g.
// Windows LLP64), which would degrade the generator -- confirm platforms.
struct xorshift_generator
{
    // Internal state; must never be zero (zero is a fixed point of the
    // xorshift step, so the generator would emit 0 forever).
    unsigned long x;

    // Seed the state. XOR with a fixed constant decorrelates small seeds,
    // but a seed equal to that constant would zero the state, so guard it.
    xorshift_generator(unsigned long seed = 0) : x(521288629ULL ^ seed)
    {
        if(x == 0)
            x = 1;
    }

    // Advance the state and map the raw 64-bit output into T via
    // normalize<T> (defined earlier in this file).
    constexpr T operator()() noexcept
    {
        x ^= x >> 12U;
        x ^= x << 25U;
        x ^= x >> 27U;
        return normalize<T>(x * 0x2545F4914F6CDD1D);
    }
};
// Produce deterministic pseudo-random data for a tensor of shape `s`.
// The same seed always yields the same sequence, so test inputs are
// reproducible across runs.
template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
{
    const auto count = s.elements();
    xorshf96_generator<T> gen{seed};
    std::vector<T> data;
    data.reserve(count);
    for(std::size_t i = 0; i < count; ++i)
        data.push_back(gen());
    return data;
}
......
......@@ -16,7 +16,23 @@ using vec4 = T __attribute__((ext_vector_type(4)));
template <class T>
// Reinterpret a scalar pointer as a pointer to 4-wide vectors so the
// vectorized kernels can do full-width loads/stores through it.
// Throws std::runtime_error if the address is not 32-byte aligned.
vec4<T>* as_vec4(T* x)
{
    // Validate alignment before casting: misaligned vector accesses on the
    // device would be undefined/faulting, so fail loudly on the host.
    std::uintptr_t a = reinterpret_cast<std::uintptr_t>(x);
    // NOTE(review): 32 is hard-coded; presumably it matches the 32-byte
    // padding applied by the allocation pass rather than alignof(vec4<T>)
    // (16 for float) -- confirm the two stay in sync.
    if(a % 32 != 0)
        throw std::runtime_error("Memory not aligned for vector operations");
    return reinterpret_cast<vec4<T>*>(x);
}
// Gather the i-th group of four consecutive scalars from `x` into a
// 4-wide vector. Element-wise copies, so `x` needs no special alignment
// (unlike as_vec4 above).
template <class T>
vec4<T> vec4_load(T* x, size_t i)
{
    vec4<T> v;
    const auto base = i * 4;
    for(size_t k = 0; k < 4; ++k)
        v[k] = x[base + k];
    return v;
}
template <class... Ts>
......@@ -64,6 +80,33 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
#if 1
auto* xp = input1.data();
auto* yp = input2.data();
auto* outp = output.data();
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size();
launch(nglobal, nlocal)([=](auto idx) __device__ {
__shared__ type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = yp[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = i % bdim_len;
auto b = buffer[bidx];
type x = xp[i];
outp[i] = f(x, b);
}
});
#else
auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data());
auto* outp = as_vec4(output.data());
......@@ -72,18 +115,31 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size;
const std::size_t rem = output.size() % vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size;
const std::size_t bdim_vec_rem = bdim_len % vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ {
__shared__ vec4<type> buffer[2048];
for(size_t i = idx.local; i < bdim_len / vec_size; i += nlocal)
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = yp[i];
}
// Load remaining bias data
for(size_t i = idx.local; i < bdim_vec_rem; i += nlocal)
{
buffer[bdim_vec_len][i] = yp[bdim_vec_len][i];
}
for(size_t i = idx.local; i < (vec_size-bdim_vec_rem); i += nlocal)
{
buffer[bdim_vec_len][i] = yp[0][i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = i % bdim_vec_len;
auto bidx = bdim_vec_len == 0 ? 0 : i % bdim_vec_len;
auto b = buffer[bidx];
vec4<type> x = xp[i];
vec4<type> out = outp[i];
......@@ -93,7 +149,12 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
}
outp[i] = out;
}
for(size_t i = idx.global; i < rem; i += nglobal)
{
outp[n][i] = f(xp[n][i], buffer[bdim_vec_len][i]);
}
});
#endif
});
};
}
......@@ -157,20 +218,26 @@ auto nary(argument result, Arguments... args)
// Dispatch a binary elementwise operation, using the shared-memory
// broadcast kernel when arg2 is a single-axis broadcast that fits the
// kernel's constraints; otherwise fall back to the generic nary_impl.
inline auto nary(argument result, argument arg1, argument arg2)
{
    return [=](auto f) {
        // TODO: Check result and arg1 shape is the same
        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted())
        {
            auto not_zero          = [](auto x) { return x != 0; };
            const auto& strides    = arg2.get_shape().strides();
            auto stride_it         = std::find_if(strides.begin(), strides.end(), not_zero);
            // The fast path requires exactly one non-zero stride. A fully
            // zero stride vector (scalar broadcast) has no broadcast axis,
            // and indexing lens() with strides.size() below would be out of
            // bounds -- so both degenerate cases fall through to nary_impl.
            if(stride_it != strides.end() and
               std::none_of(std::next(stride_it), strides.end(), not_zero))
            {
                auto stride_idx = std::distance(strides.begin(), stride_it);
                auto stride_len = arg2.get_shape().lens()[stride_idx];
                // TODO: Don't require divisibility by 4
                bool divisible_by_4 =
                    (stride_len % 4 == 0) and (arg1.get_shape().elements() % 4 == 0);
                // 2048 matches the kernel's fixed-size LDS staging buffer.
                if(divisible_by_4 and stride_len <= 2048)
                {
                    binary_broadcast(result, arg1, arg2)(f);
                    return;
                }
            }
        }
        nary_impl(result, arg1, arg2)(f);
    };
}
......
......@@ -20,7 +20,7 @@ void eliminate_allocation::apply(program& p) const
continue;
allocs.emplace_back(ins, n);
std::size_t size = ins->get_shape().bytes();
n += size + (size % 4);
n += size + (size % 32);
}
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
for(auto&& pp : allocs)
......
......@@ -77,6 +77,12 @@ struct auto_print
};
std::array<std::function<void()>, 2> auto_print::handlers = {};
// Hash any std::hash-able value; used to derive per-parameter RNG seeds
// from parameter names so generated data is stable across runs.
template <class T>
auto get_hash(const T& value)
{
    std::hash<T> hasher;
    return hasher(value);
}
void compile_check(migraph::program& p, const migraph::target& t)
{
auto name = t.name();
......@@ -98,10 +104,9 @@ migraph::argument run_cpu()
auto_print pp{p, 0};
compile_check(p, migraph::cpu::cpu_target{});
migraph::program::parameter_map m;
int seed = 0;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::generate_argument(x.second, seed++);
m[x.first] = migraph::generate_argument(x.second, get_hash(x.first));
}
return p.eval(m);
}
......@@ -113,12 +118,10 @@ migraph::argument run_gpu()
auto p = v.create_program();
auto_print pp{p, 1};
compile_check(p, migraph::gpu::target{});
migraph::program::parameter_map m;
int seed = 0;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, seed++));
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second, get_hash(x.first)));
}
return migraph::gpu::from_gpu(p.eval(m));
......@@ -133,8 +136,10 @@ void verify_args(const std::string& name,
{
// TODO: Check for nans
std::cout << "FAILED: " << name << std::endl;
// std::cout << cpu << std::endl;
// std::cout << gpu << std::endl;
if(cpu.size() < 32)
std::cout << "cpu:" << cpu << std::endl;
if(gpu.size() < 32)
std::cout << "gpu:" << gpu << std::endl;
if(migraph::range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(migraph::range_zero(gpu))
......@@ -156,6 +161,7 @@ void verify_args(const std::string& name,
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
std::cout << std::endl;
}
});
}
......@@ -226,6 +232,34 @@ struct test_add_broadcast2
}
};
// Verifies add with a length-4 vector broadcast along axis 1 of a
// {2, 4, 5} tensor (exercises the divisible-by-4 fast broadcast path).
struct test_add_broadcast3
{
    migraph::program create_program() const
    {
        migraph::program p;
        // Removed unused local shape `s` ({3} copy-paste residue that did
        // not even match this test's broadcast axis length of 4).
        auto x  = p.add_parameter("x", {migraph::shape::float_type, {2, 4, 5}});
        auto y  = p.add_parameter("y", {migraph::shape::float_type, {4}});
        auto by = p.add_instruction(migraph::broadcast{1}, x, y);
        p.add_instruction(migraph::add{}, x, by);
        return p;
    }
};
// Verifies add with a length-3 vector broadcast along axis 1 of a
// {2, 3, 5} tensor (broadcast length NOT divisible by 4, exercising the
// generic fallback path in nary).
struct test_add_broadcast4
{
    migraph::program create_program() const
    {
        migraph::program p;
        // Removed unused local shape `s` (copy-paste residue).
        auto x  = p.add_parameter("x", {migraph::shape::float_type, {2, 3, 5}});
        auto y  = p.add_parameter("y", {migraph::shape::float_type, {3}});
        auto by = p.add_instruction(migraph::broadcast{1}, x, y);
        p.add_instruction(migraph::add{}, x, by);
        return p;
    }
};
struct test_conv_relu
{
migraph::program create_program() const
......@@ -435,6 +469,8 @@ int main()
verify_program<test_add>();
verify_program<test_add_broadcast>();
verify_program<test_add_broadcast2>();
verify_program<test_add_broadcast3>();
verify_program<test_add_broadcast4>();
verify_program<test_conv_relu>();
verify_program<test_add_relu>();
verify_program<test_conv_pooling>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment