"...torch/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "861b5ce2bdd697b17b0ce759a5f5f126cd3a8915"
Commit 3908c88b authored by Anthony Chang

host softmax can run with pre-calculated stats for debug purposes

parent 0aafc6be
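The pre-calculated stats stored by the reference below are the per-slice log-sum-exp values, i.e. max + ln(sum) as the member comment puts it. Given that single value s per reduced slice, the softmax can be reproduced in one pass, mirroring the comment in RunWithPreCalcStats:

\[
s = m + \ln\sum_j e^{x_j - m}, \qquad
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
= \frac{e^{x_i - m}}{e^{\,s - m}} = e^{x_i - s}.
\]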
@@ -26,20 +26,34 @@ struct ReferenceSoftmax : public device::BaseOperator
                  Tensor<OutDataType>& out,
                  AccDataType alpha,
                  AccDataType beta,
-                 const std::vector<index_t> sm_reduce_dims)
-            : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
+                 const std::vector<index_t> sm_reduce_dims,
+                 Tensor<AccDataType>* sm_stats_ptr = nullptr)
+            : in_(in),
+              out_(out),
+              alpha_(alpha),
+              beta_(beta),
+              sm_reduce_dims_(sm_reduce_dims),
+              sm_stats_ptr_(sm_stats_ptr)
         {
-            // std::cout << "debug: scalar dims: ";
             for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
             {
                 if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
                    sm_reduce_dims.end())
                 {
-                    sm_scalar_dims_.push_back(i);
-                    // std::cout << i << ", ";
+                    sm_stats_dims_.push_back(i);
                 }
             }
-            // std::cout << std::endl;
+
+            for(index_t dim : sm_stats_dims_)
+            {
+                sm_stats_lengths_.push_back(in_.mDesc.GetLengths()[dim]);
+            }
+
+            // max and sum reduction with final reduced values of dim=0 is a scalar so give it
+            // appropriate lengths of {1}
+            if(sm_stats_dims_.size() == 0)
+            {
+                sm_stats_lengths_.push_back(1);
+            }
         }

         const Tensor<InDataType>& in_;
@@ -47,7 +61,9 @@ struct ReferenceSoftmax : public device::BaseOperator
         AccDataType alpha_;
         AccDataType beta_;
         std::vector<index_t> sm_reduce_dims_;
-        std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
+        std::vector<index_t> sm_stats_dims_; // dim after internal max/sum reduction
+        std::vector<size_t> sm_stats_lengths_;
+        Tensor<AccDataType>* sm_stats_ptr_; // max + ln(sum)
     };

     // Invoker
@@ -55,30 +71,18 @@ struct ReferenceSoftmax : public device::BaseOperator
     {
         float Run(const Argument& arg)
         {
-            std::vector<size_t> scalar_lengths;
-            for(index_t dim : arg.sm_scalar_dims_)
-            {
-                scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
-            }
-            // max and sum reduction with final reduced values of dim=0 is a scalar so give it
-            // appropriate lengths of {1}
-            if(arg.sm_scalar_dims_.size() == 0)
-            {
-                scalar_lengths.push_back(1);
-            }
-
-            Tensor<AccDataType> reduce_max(scalar_lengths);
+            Tensor<AccDataType> reduce_max(arg.sm_stats_lengths_);
             reduce_max.GenerateTensorValue(
                 GeneratorTensor_1<AccDataType>{std::numeric_limits<AccDataType>::lowest()});
-            Tensor<AccDataType> reduce_sum(scalar_lengths);
+            Tensor<AccDataType> reduce_sum(arg.sm_stats_lengths_);
             reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});

             // when final reduced values is of dim=0, the index will be transformed into empty
             // std::vector which is actually a valid input for Tensor::operator(std::vector) and
             // internally accesses 0'th element
-            auto to_sm_scalar_idx = [&](auto idx) {
+            auto to_sm_stats_idx = [&](auto idx) {
                 std::vector<size_t> sm_scalar_idx;
-                for(index_t dim : arg.sm_scalar_dims_)
+                for(index_t dim : arg.sm_stats_dims_)
                 {
                     sm_scalar_idx.push_back(idx[dim]);
                 }
@@ -86,42 +90,66 @@ struct ReferenceSoftmax : public device::BaseOperator
             };

             arg.in_.ForEach([&](auto& self, auto idx) {
-                reduce_max(to_sm_scalar_idx(idx)) = std::max(
-                    reduce_max(to_sm_scalar_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
+                reduce_max(to_sm_stats_idx(idx)) = std::max(
+                    reduce_max(to_sm_stats_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
             });
-            // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
-            // std::endl;

             Tensor<AccDataType> in_stable(arg.in_.mDesc);
             in_stable.ForEach([&](auto& self, auto idx) {
                 // numerator = exp(x - max(x))
                 self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
-                                     reduce_max(to_sm_scalar_idx(idx)));
+                                     reduce_max(to_sm_stats_idx(idx)));
             });
-            // LogRangeAsType<float>(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl;

             in_stable.ForEach([&](auto& self, auto idx) {
                 // denominator = sum(exp(x - max(x)))
-                reduce_sum(to_sm_scalar_idx(idx)) += self(idx);
+                reduce_sum(to_sm_stats_idx(idx)) += self(idx);
             });
-            // LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",") <<
-            // std::endl;
+
+            if(arg.sm_stats_ptr_)
+            {
+                arg.sm_stats_ptr_->ForEach([&](auto& self, auto idx) {
+                    self(idx) = reduce_max(idx) + std::log(reduce_sum(idx));
+                });
+            }

             arg.out_.ForEach([&](auto& self, auto idx) {
                 AccDataType temp_result =
-                    arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
+                    arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_stats_idx(idx)) +
                     arg.beta_ * self(idx);
                 self(idx) = ck::type_convert<OutDataType>(temp_result);
             });
-            // LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
-            // reduction along reduce dims
-            // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
-            // std::endl; LogRangeAsType<float>(std::cout << "reduce_sum: ", reduce_sum.mData, ",")
-            // << std::endl;
+
+            return 0;
+        }
+
+        float RunWithPreCalcStats(const Argument& arg)
+        {
+            if(arg.sm_stats_lengths_ != arg.sm_stats_ptr_[0].GetLengths())
+            {
+                throw std::runtime_error(
+                    "softmax stats shape must match shape after softmax sum reduction op");
+            }
+
+            // when final reduced values is of dim=0, the index will be transformed into empty
+            // std::vector which is actually a valid input for Tensor::operator(std::vector) and
+            // internally accesses 0'th element
+            auto to_sm_stats_idx = [&](auto idx) {
+                std::vector<size_t> sm_scalar_idx;
+                for(index_t dim : arg.sm_stats_dims_)
+                {
+                    sm_scalar_idx.push_back(idx[dim]);
+                }
+                return sm_scalar_idx;
+            };
+
+            // each element in stats corresponds to max + log(sum) after reduction
+            // exp(x - max) / sum = exp(x - max) / exp(log(sum)) = exp(x - (max + log(sum)))
+            arg.out_.ForEach([&](auto& self, auto idx) {
+                self(idx) = arg.alpha_ * std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
+                                                  ck::type_convert<AccDataType>(
+                                                      arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
+                            arg.beta_ * self(idx);
+            });

             return 0;
         }

@@ -145,9 +173,10 @@ struct ReferenceSoftmax : public device::BaseOperator
                              Tensor<OutDataType>& out,
                              AccDataType alpha,
                              AccDataType beta,
-                             const std::vector<index_t> sm_reduce_dims)
+                             const std::vector<index_t> sm_reduce_dims,
+                             Tensor<AccDataType>* stats = nullptr)
     {
-        return Argument{in, out, alpha, beta, sm_reduce_dims};
+        return Argument{in, out, alpha, beta, sm_reduce_dims, stats};
     }

     static auto MakeInvoker() { return Invoker{}; }
...
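For intuition, here is a minimal standalone sketch of the same recomputation on a single slice, using plain std::vector rather than the CK Tensor API (all names below are illustrative and not part of this change): compute the stats value m + log(sum) once, then rebuild the softmax from it in a single pass, as RunWithPreCalcStats does per reduced slice.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> x = {1.f, 0.f, 0.f, 0.f}; // one row of a 4x4 identity

    // pass 1: running max of the slice
    float m = *std::max_element(x.begin(), x.end());

    // pass 2: sum of shifted exponentials
    float sum = 0.f;
    for(float v : x)
        sum += std::exp(v - m);

    // the per-slice "stats" value the reference stores: max + log(sum)
    float stats = m + std::log(sum);

    // softmax from stats only: exp(x - (max + log(sum))) == exp(x - max) / sum
    for(float v : x)
    {
        float direct     = std::exp(v - m) / sum;
        float from_stats = std::exp(v - stats);
        assert(std::fabs(direct - from_stats) < 1e-6f);
        std::printf("%f %f\n", direct, from_stats);
    }
    return 0;
}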
@@ -3,9 +3,12 @@ add_custom_target(test_softmax)
 add_gtest_executable(test_softmax_rank3 test_softmax_rank3.cpp)
 add_gtest_executable(test_softmax_rank4 test_softmax_rank4.cpp)
 add_gtest_executable(test_softmax_interface test_softmax_interface.cpp)
+add_gtest_executable(test_softmax_host_ref test_softmax_host_ref.cpp)
 target_link_libraries(test_softmax_rank3 PRIVATE utility device_softmax_instance)
 target_link_libraries(test_softmax_rank4 PRIVATE utility device_softmax_instance)
 target_link_libraries(test_softmax_interface PRIVATE utility device_softmax_instance)
+target_link_libraries(test_softmax_host_ref PRIVATE utility)
 add_dependencies(test_softmax test_softmax_rank3)
 add_dependencies(test_softmax test_softmax_rank4)
 add_dependencies(test_softmax test_softmax_interface)
+add_dependencies(test_softmax test_softmax_host_ref)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <vector>

#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

#include "gtest/gtest.h"

using namespace ck;

TEST(ReferenceSoftmax, Run)
{
    Tensor<float> x({2, 2});
    Tensor<float> y({2, 2});
    x.GenerateTensorValue(GeneratorTensor_Diagonal<float>{});

    using ReferenceSoftmax = tensor_operation::host::ReferenceSoftmax<float, float, float>;

    float alpha = 1.f;
    float beta  = 0.f;

    auto ref_softmax          = ReferenceSoftmax{};
    auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
    auto ref_softmax_argument = ref_softmax.MakeArgument(x, y, alpha, beta, {1});

    ref_softmax_invoker.Run(ref_softmax_argument);

    EXPECT_TRUE((utils::check_err(y.mData, {0.73105858, 0.268941421, 0.26894142, 0.73105858})));
}

TEST(ReferenceSoftmax, RunWithCalculatedStats)
{
    // >>> x = np.eye(4)
    // >>> m = np.max(np.exp(x), axis=1, keepdims=True)
    // >>> l = np.sum(np.exp(x - np.tile(m, (1,4))), axis=1, keepdims=True)
    // >>> m + np.log(l)
    // array([[1.74366838],
    //        [1.74366838],
    //        [1.74366838],
    //        [1.74366838]])
    Tensor<float> x({4, 4});
    Tensor<float> y({4, 4});
    Tensor<float> stats({4});
    x.GenerateTensorValue(GeneratorTensor_Diagonal<float>{});

    using ReferenceSoftmax = tensor_operation::host::ReferenceSoftmax<float, float, float>;

    float alpha = 1.f;
    float beta  = 0.f;

    auto ref_softmax         = ReferenceSoftmax{};
    auto ref_softmax_invoker = ref_softmax.MakeInvoker();

    {
        auto ref_softmax_argument = ref_softmax.MakeArgument(x, y, alpha, beta, {1}, &stats);
        ref_softmax_invoker.Run(ref_softmax_argument);

        EXPECT_TRUE(
            (utils::check_err(stats.mData, {1.74366838, 1.74366838, 1.74366838, 1.74366838})));
    }

    {
        Tensor<float> yy({4, 4});
        auto ref_softmax_argument = ref_softmax.MakeArgument(x, yy, alpha, beta, {1}, &stats);
        ref_softmax_invoker.RunWithPreCalcStats(ref_softmax_argument);

        EXPECT_TRUE((utils::check_err(y.mData, yy.mData)));
    }
}
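A quick check of the expected stats value in RunWithCalculatedStats: each row of the 4x4 identity holds one 1 and three 0s, so

\[
m = 1, \qquad
\sum_j e^{x_j - m} = e^{0} + 3e^{-1} \approx 2.1036383, \qquad
m + \ln\!\Big(\sum_j e^{x_j - m}\Big) \approx 1.7436684,
\]

which matches the 1.74366838 expected for every row (equivalently, \(\ln(e + 3)\), since the stats value is just the row's log-sum-exp regardless of the shift used).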