"data/git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "7ea81099235fd4ccf8d4b9ba202e76cce40b5cc8"
Unverified Commit 0a2f60ba authored by sherie's avatar sherie Committed by GitHub
Browse files

[Fix] Fix roi_align npu bug (#2862)

parent 43c5c76f
...@@ -7,13 +7,14 @@ void roi_align_forward_npu(Tensor input, Tensor rois, Tensor output, ...@@ -7,13 +7,14 @@ void roi_align_forward_npu(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x, int aligned_height, Tensor argmax_y, Tensor argmax_x, int aligned_height,
int aligned_width, float spatial_scale, int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned) { int sampling_ratio, int pool_mode, bool aligned) {
int64_t roi_end_mode = 2;
if (!aligned) { if (!aligned) {
LOG(WARNING) << "The [aligned] attr in roi_align op is false"; LOG(WARNING) << "The [aligned] attr in roi_align op is false";
roi_end_mode = 0;
} }
int64_t aligned_height_64 = aligned_height; int64_t aligned_height_64 = aligned_height;
int64_t aligned_width_64 = aligned_width; int64_t aligned_width_64 = aligned_width;
int64_t sampling_ratio_64 = sampling_ratio; int64_t sampling_ratio_64 = sampling_ratio;
int64_t roi_end_mode = 0;
OpCommand cmd; OpCommand cmd;
cmd.Name("ROIAlign") cmd.Name("ROIAlign")
.Input(input) .Input(input)
...@@ -35,7 +36,11 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y, ...@@ -35,7 +36,11 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y,
int64_t aligned_height_64 = aligned_height; int64_t aligned_height_64 = aligned_height;
int64_t aligned_width_64 = aligned_width; int64_t aligned_width_64 = aligned_width;
int64_t sampling_ratio_64 = sampling_ratio; int64_t sampling_ratio_64 = sampling_ratio;
int64_t roi_end_mode = 0; int64_t roi_end_mode = 2;
if (!aligned) {
LOG(WARNING) << "The [aligned] attr in roi_align_grad op is false";
roi_end_mode = 0;
}
c10::SmallVector<int64_t, SIZE> xdiff_shape = c10::SmallVector<int64_t, SIZE> xdiff_shape =
at_npu::native::array_to_small_vector(grad_input.sizes()); at_npu::native::array_to_small_vector(grad_input.sizes());
OpCommand cmd; OpCommand cmd;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment