Unverified Commit fe83843b authored by HunterTracer's avatar HunterTracer Committed by GitHub
Browse files

add SCATTER_API definition for scatter_mul in scatter.cpp & scatter.h (#344)


Co-authored-by: Hunter_Tracer's avatarzenghongtai <1518445275@qq.com>
parent 111ffc42
...@@ -239,7 +239,8 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -239,7 +239,8 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0]; return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
} }
torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim, SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out, torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) { torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0]; return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
......
...@@ -15,6 +15,11 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -15,6 +15,11 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out, torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size); torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out, torch::optional<torch::Tensor> optional_out,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment