"docs/source/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "a9e4574261a20d4ada213d26671da7dc7633580b"
Commit f056396b authored by rusty1s's avatar rusty1s
Browse files

fixed negative dims

parent d3aabdf3
...@@ -17,9 +17,6 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -17,9 +17,6 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
for (auto i = 0; i < index.dim() - 1; i++) for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i)); CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous(); src = src.contiguous();
torch::Tensor out; torch::Tensor out;
......
...@@ -69,9 +69,6 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -69,9 +69,6 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
for (auto i = 0; i < index.dim() - 1; i++) for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i)); CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous(); src = src.contiguous();
torch::Tensor out; torch::Tensor out;
......
...@@ -8,8 +8,6 @@ ...@@ -8,8 +8,6 @@
#endif #endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) { torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (dim < 0)
dim = other.dim() + dim;
if (src.dim() == 1) if (src.dim() == 1)
for (auto i = 0; i < dim; i++) for (auto i = 0; i < dim; i++)
src = src.unsqueeze(0); src = src.unsqueeze(0);
...@@ -43,6 +41,7 @@ public: ...@@ -43,6 +41,7 @@ public:
Variable index, int64_t dim, Variable index, int64_t dim,
torch::optional<Variable> optional_out, torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) { torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim; ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes(); ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim); index = broadcast(index, src, dim);
...@@ -116,6 +115,7 @@ public: ...@@ -116,6 +115,7 @@ public:
Variable index, int64_t dim, Variable index, int64_t dim,
torch::optional<Variable> optional_out, torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) { torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim; ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes(); ctx->saved_data["src_shape"] = src.sizes();
...@@ -151,6 +151,7 @@ public: ...@@ -151,6 +151,7 @@ public:
Variable index, int64_t dim, Variable index, int64_t dim,
torch::optional<Variable> optional_out, torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) { torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim; ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes(); ctx->saved_data["src_shape"] = src.sizes();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment