Commit a0939977 authored by ZShaopeng's avatar ZShaopeng Committed by Zaida Zhou
Browse files

[Feature] Support MultiScaleDeformableAttn with cambricon MLU backend

parent 193de43b
......@@ -33,7 +33,7 @@ We implement common ops used in detection, segmentation, etc.
| MergeCells | | √ | | |
| MinAreaPolygon | | √ | | |
| ModulatedDeformConv2d | √ | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
......
......@@ -33,7 +33,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| MergeCells | | √ | | |
| MinAreaPolygon | | √ | | |
| ModulatedDeformConv2d | √ | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
......
......@@ -362,4 +362,37 @@ __mlu_func__ inline void convertFloat2half(half *dst, float *src,
#endif
}
/*!
* @brief recursiveSumPool.
* @param[in,out] dst
* Pointer to NRAM that stores the input and output data.
* @param[in] low_dim
* Which is the number of low dim.
* @param[in] high_dim
* Which is the number of high dim.
* @param[in] kernel_limit
* Which is the high_dim of sumpool per time.
******************************************************************************/
template <typename T>
__mlu_func__ void recursiveSumPool(T *dst, int low_dim, int high_dim,
int kernel_limit) {
for (; high_dim > 1;) {
int repeat_s = high_dim / kernel_limit;
int remain_s = high_dim % kernel_limit;
if (remain_s) {
__bang_sumpool((T *)dst, (T *)dst, low_dim, 1, remain_s, 1, remain_s, 1,
1);
}
if (repeat_s) {
__bang_sumpool((T *)dst + (remain_s > 0 ? low_dim : 0),
(T *)dst + remain_s * low_dim, low_dim,
kernel_limit * repeat_s, 1, kernel_limit, 1, 1,
kernel_limit);
}
high_dim = repeat_s + (bool)remain_s;
}
return;
}
#endif // COMMON_MLU_HELPER_HPP_
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
#include <math.h>
/****************************************************************************************
*
* NRAM partition forward:
* | spatial_shapes | data_value_p1_ping | data_value_p2_ping |
* | data_value_p3_ping | data_value_p4_ping | data_col_ping |
* | data_value_p1_pong | data_value_p2_pong | data_value_p3_pong |
* | data_value_p4_pong | data_col_pong | auxiliary_a |
* | auxiliary_b |
* | 128bytes | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size |
*
****************************************************************************************/
/****************************************************************************************
*
* NRAM partition backward:
* | grad_output_nram | grad_output_nram_temp | grad_weight |
* | grad_h_weight | grad_w_weight | top_grad |
* | top_grad_temp | spatial_shapes_nram | sampling_loc_nram |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | deal_size |
* | deal_size | deal_size | 64bytes |
*
****************************************************************************************/
#define TWELVE_SPLIT 12
#define ALIGN_NUM 64
#define ALIGN_NUM_FOR_REDUCE 32
__nram__ char nram_buffer[MAX_NRAM_SIZE];
template <typename T>
__mlu_func__ void loadNeighborPointsData(
const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram,
T *data_value_p3_nram, T *data_value_p4_nram, const size_t deal_num,
const int32_t &width, const int32_t &height, const int32_t &num_heads,
const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = width * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
const int32_t base_ptr_offset = head_idx * channels;
// top-left point
if (h_low >= 0 && w_low >= 0) {
const int32_t v1_offset =
h_low_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p1_nram, data_value_gdram + v1_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
const int32_t v2_offset =
h_low_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p2_nram, data_value_gdram + v2_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
const int32_t v3_offset =
h_high_ptr_offset + w_low_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p3_nram, data_value_gdram + v3_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
const int32_t v4_offset =
h_high_ptr_offset + w_high_ptr_offset + base_ptr_offset;
__memcpy_async(data_value_p4_nram, data_value_gdram + v4_offset,
deal_num * sizeof(T), GDRAM2NRAM);
}
}
template <typename T>
__mlu_func__ void bilinearInterpolation(
T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram,
T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b,
const size_t deal_num, const int32_t &width, const int32_t &height,
const T &x, const T &y) {
const int32_t w_low = floorf(x);
const int32_t h_low = floorf(y);
const int32_t w_high = w_low + 1;
const int32_t h_high = h_low + 1;
const T lw = x - w_low;
const T lh = y - h_low;
const T hw = 1 - lw;
const T hh = 1 - lh;
const T w1 = hh * hw;
const T w2 = hh * lw;
const T w3 = lh * hw;
const T w4 = lh * lw;
__bang_write_value((T *)sample_point_value, deal_num, (T)0);
// top-left point
if (h_low >= 0 && w_low >= 0) {
// sample_point_value += v1 * w1
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p1_nram, (T)w1,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// top-right point
if (h_low >= 0 && w_high <= width - 1) {
// sample_point_value += v2 * w2
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p2_nram, (T)w2,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-left point
if (h_high <= height - 1 && w_low >= 0) {
// sample_point_value += v3 * w3
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p3_nram, (T)w3,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
// bottom-right point
if (h_high <= height - 1 && w_high <= width - 1) {
// sample_point_value += v4 * w4
__bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p4_nram, (T)w4,
deal_num);
__bang_add((T *)sample_point_value, (T *)sample_point_value,
(T *)auxiliary_b, deal_num);
}
}
template <typename T>
__mlu_global__ void MLUKernelMsDeformAttnForward(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
if (coreId == 0x80) {
return;
}
const size_t spatial_size = PAD_UP(2 * sizeof(int32_t), NFU_ALIGN_SIZE);
const size_t span_num_deal =
PAD_DOWN((MAX_NRAM_SIZE - spatial_size) / TWELVE_SPLIT / sizeof(T),
NFU_ALIGN_SIZE);
const size_t align_num = NFU_ALIGN_SIZE;
const int32_t channels_seg_num = channels / span_num_deal;
const size_t channels_rem = channels % span_num_deal;
const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num);
char *data_spatial_shapes_nram = nram_buffer;
char *ping_data_value_p1_nram = data_spatial_shapes_nram + spatial_size;
char *ping_data_value_p2_nram =
ping_data_value_p1_nram + span_num_deal * sizeof(T);
char *ping_data_value_p3_nram =
ping_data_value_p2_nram + span_num_deal * sizeof(T);
char *ping_data_value_p4_nram =
ping_data_value_p3_nram + span_num_deal * sizeof(T);
char *ping_data_col_nram =
ping_data_value_p4_nram + span_num_deal * sizeof(T);
char *pong_data_value_p1_nram =
ping_data_col_nram + span_num_deal * sizeof(T);
char *pong_data_value_p2_nram =
pong_data_value_p1_nram + span_num_deal * sizeof(T);
char *pong_data_value_p3_nram =
pong_data_value_p2_nram + span_num_deal * sizeof(T);
char *pong_data_value_p4_nram =
pong_data_value_p3_nram + span_num_deal * sizeof(T);
char *pong_data_col_nram =
pong_data_value_p4_nram + span_num_deal * sizeof(T);
char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T);
char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T);
const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T);
size_t data_col_ping_pong_idx = 0;
int32_t block_num_per_core = (batch_size * num_queries * num_heads) / taskDim;
const int32_t block_num_rem =
(batch_size * num_queries * num_heads) % taskDim;
const int32_t idx_start = taskId < (block_num_rem + 1)
? taskId * (block_num_per_core + 1)
: taskId * block_num_per_core + block_num_rem;
block_num_per_core =
taskId < block_num_rem
? (batch_size * num_queries * num_heads) / taskDim + 1
: (batch_size * num_queries * num_heads) / taskDim;
for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core;
++cur_idx) {
// cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads +
// head_idx
const int32_t head_idx = cur_idx % num_heads;
const int32_t batch_idx = (cur_idx / num_heads) / num_queries;
const char *data_value_gdram_start =
data_value_gdram +
batch_idx * num_keys * num_heads * channels * sizeof(T);
const char *data_sampling_loc_gdram_start =
data_sampling_loc_gdram +
cur_idx * num_levels * num_points * 2 * sizeof(T);
const char *data_attn_weight_gdram_start =
data_attn_weight_gdram + cur_idx * num_levels * num_points * sizeof(T);
char *data_col_gdram_start =
data_col_gdram + cur_idx * channels * sizeof(T);
for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
span_num_deal, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// last point no need to load data, continue to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
c_seg_idx * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_h_next_point = spatial_h;
spatial_w_next_point = spatial_w;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
span_num_deal, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
bilinearInterpolation(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b, span_num_deal, spatial_w,
spatial_h, x, y);
__bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight,
span_num_deal);
__bang_add((T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)auxiliary_a, span_num_deal);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
span_num_deal * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
if (channels_rem > 0) {
__bang_write_value(
(T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap),
channels_align_rem, (T)0);
// load data
// level_idx = 0, point_idx = 0
__memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram,
2 * sizeof(int32_t), GDRAM2NRAM);
int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0];
int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1];
const char *data_value_ptr =
data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T);
T loc_w = ((T *)data_sampling_loc_gdram_start)[0];
T loc_h = ((T *)data_sampling_loc_gdram_start)[1];
T weight = ((T *)data_attn_weight_gdram_start)[0];
T x = loc_w * spatial_w - 0.5;
T y = loc_h * spatial_h - 0.5;
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr, (T *)ping_data_value_p1_nram,
(T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram,
(T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h,
num_heads, channels, x, y, head_idx);
}
T spatial_h_next_point = 0;
T spatial_w_next_point = 0;
T weight_next_point = 0;
T x_next_point = 0;
T y_next_point = 0;
__asm__ volatile("sync;");
for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) {
for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) {
// load data
if (point_idx == num_points - 1 && level_idx == num_levels - 1) {
// last point no need to load data, continue to compute
} else if (point_idx == num_points - 1) {
const int32_t level_start_id =
((int32_t *)data_level_start_index_gdram)[level_idx + 1];
const int32_t spatial_h_ptr = (level_idx + 1) << 1;
__memcpy(
data_spatial_shapes_nram,
data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t),
2 * sizeof(int32_t), GDRAM2NRAM);
spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0];
spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1];
data_value_ptr = data_value_gdram_start +
(level_start_id * num_heads * channels +
channels_seg_num * span_num_deal) *
sizeof(T);
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w_next_point - 0.5;
y_next_point = loc_h * spatial_h_next_point - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h_next_point &&
x_next_point < spatial_w_next_point) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w_next_point, spatial_h_next_point,
num_heads, channels, x_next_point, y_next_point, head_idx);
}
} else {
spatial_w_next_point = spatial_w;
spatial_h_next_point = spatial_h;
loc_w = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2];
loc_h = ((T *)data_sampling_loc_gdram_start)
[(level_idx * num_points + point_idx + 1) * 2 + 1];
weight_next_point =
((T *)data_attn_weight_gdram_start)[level_idx * num_points +
point_idx + 1];
x_next_point = loc_w * spatial_w - 0.5;
y_next_point = loc_h * spatial_h - 0.5;
if (y_next_point > -1 && x_next_point > -1 &&
y_next_point < spatial_h && x_next_point < spatial_w) {
loadNeighborPointsData(
(T *)data_value_ptr,
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx + 1) % 2) *
ping_pong_gap),
channels_rem, spatial_w, spatial_h, num_heads, channels,
x_next_point, y_next_point, head_idx);
}
}
// compute
if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) {
bilinearInterpolation(
(T *)(ping_data_value_p1_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p2_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p3_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)(ping_data_value_p4_nram +
((level_idx * num_points + point_idx) % 2) *
ping_pong_gap),
(T *)auxiliary_a, (T *)auxiliary_b, channels_align_rem,
spatial_w, spatial_h, x, y);
__bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight,
channels_align_rem);
__bang_add((T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)(ping_data_col_nram +
data_col_ping_pong_idx * ping_pong_gap),
(T *)auxiliary_a, channels_align_rem);
}
spatial_w = spatial_w_next_point;
spatial_h = spatial_h_next_point;
weight = weight_next_point;
x = x_next_point;
y = y_next_point;
__asm__ volatile("sync;");
}
}
// store
__memcpy_async(
data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T),
ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap,
channels_rem * sizeof(T), NRAM2GDRAM);
data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2;
}
}
__asm__ volatile("sync;");
return;
}
template __mlu_global__ void MLUKernelMsDeformAttnForward<float>(
const char *data_value_gdram, const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram);
void KernelMsDeformAttnForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char *data_value_gdram,
const char *data_spatial_shapes_gdram,
const char *data_level_start_index_gdram,
const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char *data_col_gdram) {
MLUKernelMsDeformAttnForward<float><<<k_dim, k_type, queue>>>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points, data_col_gdram);
}
template <typename T>
void __mlu_func__ msDeformAttnCol2imBilinear(
T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1,
const T &w2, const T &w3, const T &w4, const int32_t &h_low,
const int32_t &w_low, const int32_t &h_high, const int32_t &w_high,
const int32_t &base_ptr, const int32_t &h_low_ptr_offset,
const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset,
const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh,
const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight,
T *grad_w_weight, T *grad_value, T *grad_output_nram, T *grad_weight,
T *grad_sampling_loc, T *grad_attn_weight, T *grad_output_nram_temp,
const int32_t &deal_num, const int32_t &deal_num_real,
const T *data_value_ptr) {
if (h_low >= 0 && w_low >= 0) {
int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram, data_value_ptr + offset1,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1),
(T *)top_grad_temp, deal_num_real);
}
if (h_low >= 0 && w_high <= width - 1) {
int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset2,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num);
__bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num);
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2,
deal_num);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_low >= 0) {
int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset3,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num);
__bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3,
deal_num);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3),
(T *)top_grad_temp, deal_num_real);
}
if (h_high <= height - 1 && w_high <= width - 1) {
int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
__memcpy(grad_output_nram_temp, data_value_ptr + offset4,
deal_num_real * sizeof(T), GDRAM2NRAM);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num);
__bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num);
__bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num);
__bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num);
// for calc grad_attn_weight
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4,
deal_num);
__bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp,
deal_num);
__bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4),
(T *)top_grad_temp, deal_num_real);
}
__bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
const int32_t align_num_on_200 = NFU_ALIGN_SIZE / sizeof(float);
recursiveSumPool(grad_output_nram, align_num_on_200,
deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_output_nram, grad_output_nram,
NFU_ALIGN_SIZE / sizeof(float));
#endif
__bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight,
(T *)grad_output_nram, 1);
__bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num);
__bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num);
__bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_w_weight, grad_w_weight,
NFU_ALIGN_SIZE / sizeof(float));
#endif
__bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc),
(T *)grad_w_weight, 1);
__bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num);
__bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num);
#if __BANG_ARCH__ >= 322
recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE);
#else
recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200,
ALIGN_NUM_FOR_REDUCE);
__bang_reduce_sum(grad_h_weight, grad_h_weight,
NFU_ALIGN_SIZE / sizeof(float));
#endif
__bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1),
(T *)grad_h_weight, 1);
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight) {
if (coreId == 0x80) {
return;
}
const int32_t split_num = 8;
const int32_t spatial_shapes_size = 64;
int32_t deal_num = PAD_DOWN(
(MAX_NRAM_SIZE - spatial_shapes_size) / split_num / sizeof(float),
ALIGN_NUM);
float *grad_output_nram = (float *)nram_buffer;
float *grad_output_nram_temp = (float *)nram_buffer + deal_num;
float *grad_weight = (float *)nram_buffer + 2 * deal_num;
float *grad_h_weight = (float *)nram_buffer + 3 * deal_num;
float *grad_w_weight = (float *)nram_buffer + 4 * deal_num;
float *top_grad = (float *)nram_buffer + 5 * deal_num;
float *top_grad_temp = (float *)nram_buffer + 6 * deal_num;
int32_t *spatial_shapes_nram =
(int32_t *)((float *)nram_buffer + 7 * deal_num);
float *sampling_loc_nram =
(float *)nram_buffer + 7 * deal_num + 2 * sizeof(int32_t);
const int32_t total_num = batch * num_query * num_heads * num_levels;
int32_t num_per_core = total_num / taskDim;
int32_t num_rem = total_num % taskDim;
num_per_core = num_per_core + int32_t(taskId < num_rem);
int32_t start_per_core =
num_rem > taskId
? (taskId * num_per_core)
: ((num_per_core + 1) * num_rem + (taskId - num_rem) * num_per_core);
int32_t end_per_core = start_per_core + num_per_core;
const int32_t C_repeat = channels / deal_num;
const int32_t C_tail = channels % deal_num;
const int32_t qid_stride = num_heads * channels;
int32_t base_ptr = 0;
for (int32_t num_loop = start_per_core; num_loop < end_per_core; ++num_loop) {
const int32_t l_col = num_loop % num_levels;
const int32_t m_col = num_loop / num_levels % num_heads;
const int32_t q_col = num_loop / num_levels / num_heads % num_query;
const int32_t b_col = num_loop / num_query / num_heads / num_levels;
int32_t data_weight_ptr = num_loop * num_points;
int32_t data_loc_w_ptr = data_weight_ptr << 1;
const int32_t value_offset = b_col * spatial_size * num_heads * channels;
const int32_t level_start_id = data_level_start_index[l_col];
int32_t spatial_h_ptr = l_col << 1;
int32_t grad_output_offset = b_col * num_query * num_heads * channels +
q_col * num_heads * channels +
m_col * channels;
__memcpy(spatial_shapes_nram, spatial_shapes + spatial_h_ptr,
2 * sizeof(int32_t), GDRAM2NRAM);
const int32_t spatial_h = spatial_shapes_nram[0];
const int32_t spatial_w = spatial_shapes_nram[1];
const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride;
const float *data_value_ptr = data_value + value_ptr_offset;
float *grad_value_ptr = grad_value + value_ptr_offset;
const int32_t grad_attn_weight_out = num_loop * num_points;
const int32_t grad_sampling_loc_out = num_loop * num_points * 2;
for (int32_t p_col = 0; p_col < num_points; ++p_col) {
__memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr,
2 * sizeof(float), GDRAM2NRAM);
const float loc_w = sampling_loc_nram[0];
const float loc_h = sampling_loc_nram[1];
const float weight = data_attn_weight[data_weight_ptr];
const float h_im = loc_h * spatial_h - 0.5;
const float w_im = loc_w * spatial_w - 0.5;
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
const int32_t h_low = floorf(h_im);
const int32_t w_low = floorf(w_im);
const int32_t h_high = h_low + 1;
const int32_t w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1.0 - lh;
const float hw = 1.0 - lw;
const int32_t w_stride = num_heads * channels;
const int32_t h_stride = spatial_w * w_stride;
const int32_t h_low_ptr_offset = h_low * h_stride;
const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int32_t w_low_ptr_offset = w_low * w_stride;
const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride;
float w1 = hh * hw;
float w2 = hh * lw;
float w3 = lh * hw;
float w4 = lh * lw;
for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) {
base_ptr = m_col * channels + C_loop * deal_num;
__bang_write_zero(grad_weight, 3 * deal_num);
__bang_write_zero(grad_output_nram, deal_num);
__memcpy(top_grad,
grad_output + grad_output_offset + C_loop * deal_num,
deal_num * sizeof(float), GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, deal_num, data_value_ptr);
}
if (C_tail != 0) {
base_ptr = m_col * channels + C_repeat * deal_num;
__bang_write_zero(grad_output_nram, 8 * deal_num);
__memcpy(top_grad,
grad_output + grad_output_offset + C_repeat * deal_num,
C_tail * sizeof(float), GDRAM2NRAM);
msDeformAttnCol2imBilinear(
top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low,
h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset,
h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad,
weight, grad_h_weight, grad_w_weight, grad_value_ptr,
grad_output_nram, grad_weight,
grad_sampling_loc + grad_sampling_loc_out + p_col * 2,
grad_attn_weight + grad_attn_weight_out + p_col,
grad_output_nram_temp, deal_num, C_tail, data_value_ptr);
}
}
data_weight_ptr += 1;
data_loc_w_ptr += 2;
}
}
}
__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward(
const float *data_value, const int32_t *spatial_shapes,
const int32_t *data_level_start_index, const float *data_sampling_loc,
const float *data_attn_weight, const float *grad_output,
const int32_t batch, const int32_t spatial_size, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float *grad_value, float *grad_sampling_loc,
float *grad_attn_weight);
void KernelMsDeformAttnBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float *data_value,
const int32_t *spatial_shapes, const int32_t *data_level_start_index,
const float *data_sampling_loc, const float *data_attn_weight,
const float *grad_output, const int32_t batch, const int32_t spatial_size,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_query, const int32_t num_points, float *grad_value,
float *grad_sampling_loc, float *grad_attn_weight) {
MLUUnion1KernelMsDeformAttnBackward<<<k_dim, k_type, queue>>>(
data_value, spatial_shapes, data_level_start_index, data_sampling_loc,
data_attn_weight, grad_output, batch, spatial_size, num_heads, channels,
num_levels, num_query, num_points, grad_value, grad_sampling_loc,
grad_attn_weight);
}
/*************************************************************************
* Copyright (C) 2022 by Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
void KernelMsDeformAttnForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const char* data_value_gdram,
const char* data_spatial_shapes_gdram,
const char* data_level_start_index_gdram,
const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram,
const int32_t batch_size, const int32_t num_keys, const int32_t num_heads,
const int32_t channels, const int32_t num_levels, const int32_t num_queries,
const int32_t num_points, char* data_col_gdram);
void KernelMsDeformAttnBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const cnrtDataType_t d_type, const float* data_value,
const int32_t* spatial_shapes, const int32_t* data_level_start_index,
const float* data_sampling_loc, const float* data_attn_weight,
const float* grad_output, const int32_t batch_size, const int32_t num_keys,
const int32_t num_heads, const int32_t channels, const int32_t num_levels,
const int32_t num_queries, const int32_t num_points, float* grad_value,
float* grad_sampling_loc, float* grad_attn_weight);
// policy function
static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type,
const int batch_size, const int num_queries,
const int num_heads) {
k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->y =
MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x,
torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
k_dim->z = 1;
#if __BANG_ARCH__ == 520
*k_type = CNRT_FUNC_TYPE_BLOCK;
#else
*k_type = CNRT_FUNC_TYPE_UNION1;
#endif
}
// policy function for backward
static void policyFuncBackward(const int32_t batch_size,
const int32_t num_queries,
const int32_t num_heads,
const int32_t num_levels,
cnrtFunctionType_t* k_type, cnrtDim3_t* k_dim) {
size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
k_dim->x = core_limit;
int32_t total_num = batch_size * num_queries * num_heads * num_levels;
size_t total_num_align = CEIL_ALIGN(total_num, core_limit);
k_dim->y = (total_num_align / core_limit) > cluster_limit
? cluster_limit
: (total_num_align / core_limit);
k_dim->z = 1;
*k_type = CNRT_FUNC_TYPE_UNION1;
}
Tensor ms_deform_attn_mlu_forward(const Tensor& value,
const Tensor& spatial_shapes,
const Tensor& level_start_index,
const Tensor& sampling_loc,
const Tensor& attn_weight,
const int im2col_step) {
// check contiguous
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
// check datatype
TORCH_CHECK((value.scalar_type() == at::kFloat),
"value type should be Float, got ", value.scalar_type(), ".");
TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
spatial_shapes.scalar_type() == at::kLong),
"spatial_shapes type should be Int, got ",
spatial_shapes.scalar_type(), ".");
TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
level_start_index.scalar_type() == at::kLong),
"level_start_index type should be Int, got ",
level_start_index.scalar_type(), ".");
TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
"sampling_loc type should be Float, got ",
sampling_loc.scalar_type(), ".");
TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
"attn_weight type should be Float, got ",
attn_weight.scalar_type(), ".");
// check shape
TORCH_CHECK(value.dim() == 4, "value should be a 4d tensor, got ",
value.dim(), "D.");
TORCH_CHECK(spatial_shapes.dim() == 2,
"spatial_shapes should be a 2d tensor, got ",
spatial_shapes.dim(), "D.");
TORCH_CHECK(level_start_index.dim() == 1,
"level_start_index should be a 1d tensor, got ",
level_start_index.dim(), "D.");
TORCH_CHECK(sampling_loc.dim() == 6,
"sampling_loc should be a 6d tensor, got ", sampling_loc.dim(),
"D.");
TORCH_CHECK(attn_weight.dim() == 5, "attn_weight should be a 5d tensor, got ",
attn_weight.dim(), "D.");
const int batch_size = value.size(0);
const int num_keys = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_queries = sampling_loc.size(1);
const int num_points = sampling_loc.size(4);
TORCH_CHECK(spatial_shapes.size(1) == 2,
"the 2nd dimensions of spatial_shapes should be 2, got ",
spatial_shapes.size(1), ".");
TORCH_CHECK(sampling_loc.size(5) == 2,
"the 6th dimensions of sampling_loc should be 2, got ",
sampling_loc.size(5), ".");
TORCH_CHECK((sampling_loc.size(0) == batch_size),
"the 1st dimensions of sampling_loc should be batch_size, ",
"but now the 1st dimension of sampling_loc is ",
sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
TORCH_CHECK((attn_weight.size(0) == batch_size),
"the 1st dimensions of attn_weight should be batch_size, ",
"but now the 1st dimension of attn_weight is ",
attn_weight.size(0), ", and batch_size is ", batch_size, ".");
TORCH_CHECK((sampling_loc.size(2) == num_heads),
"the 3rd dimensions of sampling_loc should be num_heads, ",
"but now the 3rd dimension of sampling_loc is ",
sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
TORCH_CHECK((attn_weight.size(2) == num_heads),
"the 3rd dimensions of attn_weight should be num_heads, ",
"but now the 3rd dimension of attn_weight is ",
attn_weight.size(2), ", and num_heads is ", num_heads, ".");
TORCH_CHECK((level_start_index.size(0) == num_levels),
"the 1st dimensions of level_start_index should be num_levels, ",
"but now the 1st dimension of level_start_index is ",
level_start_index.size(0), ", and num_levels is ", num_levels,
".");
TORCH_CHECK((sampling_loc.size(3) == num_levels),
"the 4th dimensions of sampling_loc should be num_levels, ",
"but now the 4th dimension of sampling_loc is ",
sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
TORCH_CHECK((attn_weight.size(3) == num_levels),
"the 4th dimensions of attn_weight should be num_levels, ",
"but now the 4th dimension of attn_weight is ",
attn_weight.size(3), ", and num_levels is ", num_levels, ".");
TORCH_CHECK((attn_weight.size(1) == num_queries),
"the 2nd dimensions of attn_weight should be num_queries, ",
"but now the 2nd dimension of attn_weight is ",
attn_weight.size(1), ", and num_queries is ", num_queries, ".");
TORCH_CHECK((attn_weight.size(4) == num_points),
"the 5th dimensions of attn_weight should be num_points, ",
"but now the 5th dimension of attn_weight is ",
attn_weight.size(4), ", and num_points is ", num_points, ".");
auto output = at::zeros({batch_size, num_queries, num_heads, channels},
value.options());
// large tensor check
const size_t max_input_size = 2147483648;
TORCH_CHECK(value.numel() < max_input_size,
"value element num should be less than 2^31, got ", value.numel(),
".");
TORCH_CHECK(sampling_loc.numel() < max_input_size,
"sampling_loc element num should be less than 2^31, got ",
sampling_loc.numel(), ".");
TORCH_CHECK(output.numel() < max_input_size,
"output element num should be less than 2^31, got ",
output.numel(), ".");
// check zero element
TORCH_CHECK(batch_size != 0, "batch_size should not be zero");
TORCH_CHECK(num_heads != 0, "num_heads should not be zero");
TORCH_CHECK(channels != 0, "channels should not be zero");
TORCH_CHECK(num_queries != 0, "num_queries should not be zero");
if (num_keys == 0 || num_levels == 0 || num_points == 0) {
return output;
}
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads);
// get compute queue
auto queue = torch_mlu::getCurQueue();
auto spatial_shapes_ = spatial_shapes.to(at::kInt);
auto level_start_index_ = level_start_index.to(at::kInt);
// get ptr of tensors
auto value_impl = torch_mlu::getMluTensorImpl(value);
auto value_ptr = value_impl->cnnlMalloc();
auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes_);
auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index_);
auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
auto output_impl = torch_mlu::getMluTensorImpl(output);
auto output_ptr = output_impl->cnnlMalloc();
// get compute dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnForward(
k_dim, k_type, queue, data_type, (char*)value_ptr,
(char*)spatial_shapes_ptr, (char*)level_start_index_ptr,
(char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys,
num_heads, channels, num_levels, num_queries, num_points,
(char*)output_ptr);
output = output.view({batch_size, num_queries, num_heads * channels});
return output;
}
void ms_deform_attn_mlu_backward(
const Tensor& value, const Tensor& spatial_shapes,
const Tensor& level_start_index, const Tensor& sampling_loc,
const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
Tensor& grad_sampling_loc, Tensor& grad_attn_weight,
const int im2col_step) {
// check contiguous
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(),
"grad_output tensor has to be contiguous");
// check datatype
TORCH_CHECK((value.scalar_type() == at::kFloat),
"value type should be Float, got ", value.scalar_type(), ".");
TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
spatial_shapes.scalar_type() == at::kLong),
"spatial_shapes type should be Int, got ",
spatial_shapes.scalar_type(), ".");
TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
level_start_index.scalar_type() == at::kLong),
"level_start_index type should be Int, got ",
level_start_index.scalar_type(), ".");
TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
"sampling_loc type should be Float, got ",
sampling_loc.scalar_type(), ".");
TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
"attn_weight type should be Float, got ",
attn_weight.scalar_type(), ".");
TORCH_CHECK((grad_output.scalar_type() == at::kFloat),
"grad_output type should be Float, got ",
grad_output.scalar_type(), ".");
const int batch_size = value.size(0);
const int num_keys = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_queries = sampling_loc.size(1);
const int num_points = sampling_loc.size(4);
// Check shape.
TORCH_CHECK(spatial_shapes.size(1) == 2,
"the 2nd dimensions of spatial_shapes should be 2, got ",
spatial_shapes.size(1), ".");
TORCH_CHECK((level_start_index.size(0) == num_levels),
"the 1st dimensions of level_start_index should be num_levels, ",
"but now the 1st dimension of level_start_index is ",
level_start_index.size(0), ", and num_levels is ", num_levels,
".");
TORCH_CHECK((sampling_loc.size(0) == batch_size),
"the 1st dimensions of sampling_loc should be batch_size, ",
"but now the 1st dimension of sampling_loc is ",
sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
TORCH_CHECK((sampling_loc.size(2) == num_heads),
"the 3rd dimensions of sampling_loc should be num_heads, ",
"but now the 3rd dimension of sampling_loc is ",
sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
TORCH_CHECK((sampling_loc.size(3) == num_levels),
"the 4th dimensions of sampling_loc should be num_levels, ",
"but now the 4th dimension of sampling_loc is ",
sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
TORCH_CHECK(sampling_loc.size(5) == 2,
"the 6th dimensions of sampling_loc should be 2, got ",
sampling_loc.size(5), ".");
TORCH_CHECK((attn_weight.size(0) == batch_size),
"the 1st dimensions of attn_weight should be batch_size, ",
"but now the 1st dimension of attn_weight is ",
attn_weight.size(0), ", and batch_size is ", batch_size, ".");
TORCH_CHECK((attn_weight.size(1) == num_queries),
"the 2nd dimensions of attn_weight should be num_queries, ",
"but now the 2nd dimension of attn_weight is ",
attn_weight.size(1), ", and num_queries is ", num_queries, ".");
TORCH_CHECK((attn_weight.size(2) == num_heads),
"the 3rd dimensions of attn_weight should be num_heads, ",
"but now the 3rd dimension of attn_weight is ",
attn_weight.size(2), ", and num_heads is ", num_heads, ".");
TORCH_CHECK((attn_weight.size(3) == num_levels),
"the 4th dimensions of attn_weight should be num_levels, ",
"but now the 4th dimension of attn_weight is ",
attn_weight.size(3), ", and num_levels is ", num_levels, ".");
TORCH_CHECK((attn_weight.size(4) == num_points),
"the 5th dimensions of attn_weight should be num_points, ",
"but now the 5th dimension of attn_weight is ",
attn_weight.size(4), ", and num_points is ", num_points, ".");
TORCH_CHECK((grad_output.size(0) == batch_size),
"the 1st dimensions of grad_output should be batch_size, ",
"but now the 1st dimension of grad_output is ",
grad_output.size(0), ", and batch_size is ", batch_size, ".");
TORCH_CHECK((grad_output.size(1) == num_queries),
"the 2nd dimensions of grad_output should be num_queries, ",
"but now the 2nd dimension of grad_output is ",
grad_output.size(1), ", and num_queries is ", num_queries, ".");
TORCH_CHECK(
(grad_output.size(2) == num_heads * channels),
"the 3rd dimensions of grad_output should be num_heads * channels, ",
"but now the 3rd dimension of grad_output is ", grad_output.size(2),
", and num_heads * channels is ", num_heads * channels, ".");
// check zero element
TORCH_CHECK(batch_size != 0, "The batch_size is zero.");
TORCH_CHECK(channels != 0, "The channels is zero.");
TORCH_CHECK(num_keys != 0, "The num_keys is zero.");
TORCH_CHECK(num_heads != 0, "The num_heads is zero.");
TORCH_CHECK(num_queries != 0, "The num_queries is zero.");
if (num_levels == 0 || num_points == 0) {
return;
}
// calculate task dimension
cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFuncBackward(batch_size, num_queries, num_heads, num_levels, &k_type,
&k_dim);
// get compute queue
auto queue = torch_mlu::getCurQueue();
// get ptr of tensors
auto value_impl = torch_mlu::getMluTensorImpl(value);
auto value_ptr = value_impl->cnnlMalloc();
auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes);
auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc();
auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index);
auto level_start_index_ptr = level_start_index_impl->cnnlMalloc();
auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc);
auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc();
auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight);
auto attn_weight_ptr = attn_weight_impl->cnnlMalloc();
auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
auto grad_output_ptr = grad_output_impl->cnnlMalloc();
auto grad_value_impl = torch_mlu::getMluTensorImpl(grad_value);
auto grad_value_ptr = grad_value_impl->cnnlMalloc();
auto grad_sampling_loc_impl = torch_mlu::getMluTensorImpl(grad_sampling_loc);
auto grad_sampling_loc_ptr = grad_sampling_loc_impl->cnnlMalloc();
auto grad_attn_weight_impl = torch_mlu::getMluTensorImpl(grad_attn_weight);
auto grad_attn_weight_ptr = grad_attn_weight_impl->cnnlMalloc();
// get comput dtype of input
cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype());
// launch kernel
CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x
<< ", " << k_dim.y << ", " << k_dim.z << ">>>";
KernelMsDeformAttnBackward(
k_dim, k_type, queue, data_type, (float*)value_ptr,
(int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr,
(float*)sampling_loc_ptr, (float*)attn_weight_ptr,
(float*)grad_output_ptr, batch_size, num_keys, num_heads, channels,
num_levels, num_queries, num_points, (float*)grad_value_ptr,
(float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr);
}
Tensor ms_deform_attn_impl_forward(const Tensor& value,
const Tensor& spatial_shapes,
const Tensor& level_start_index,
const Tensor& sampling_loc,
const Tensor& attn_weight,
const int im2col_step);
void ms_deform_attn_impl_backward(
const Tensor& value, const Tensor& spatial_shapes,
const Tensor& level_start_index, const Tensor& sampling_loc,
const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU,
ms_deform_attn_mlu_forward);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU,
ms_deform_attn_mlu_backward);
......@@ -12,6 +12,7 @@ from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.autograd.function import Function, once_differentiable
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -26,7 +27,7 @@ class MultiScaleDeformableAttnFunction(Function):
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor,
im2col_step: torch.Tensor) -> torch.Tensor:
"""GPU version of multi-scale deformable attention.
"""GPU/MLU version of multi-scale deformable attention.
Args:
value (torch.Tensor): The value has shape
......@@ -63,7 +64,7 @@ class MultiScaleDeformableAttnFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
"""GPU version of backward function.
"""GPU/MLU version of backward function.
Args:
grad_output (torch.Tensor): Gradient of output tensor of forward.
......@@ -346,7 +347,8 @@ class MultiScaleDeformableAttention(BaseModule):
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
if ((IS_CUDA_AVAILABLE and value.is_cuda)
or (IS_MLU_AVAILABLE and value.is_mlu)):
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
......
......@@ -5,6 +5,7 @@ import torch
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
_USING_PARROTS = True
try:
......@@ -14,22 +15,25 @@ except ImportError:
_USING_PARROTS = False
@pytest.mark.parametrize('device_type', [
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda:0',
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support'))
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_multiscale_deformable_attention(device_type):
def test_multiscale_deformable_attention(device):
with pytest.raises(ValueError):
# embed_dims must be divisible by num_heads,
MultiScaleDeformableAttention(
embed_dims=256,
num_heads=7,
)
device = torch.device(device_type)
device = torch.device(device)
msda = MultiScaleDeformableAttention(
embed_dims=3, num_levels=2, num_heads=3)
msda.init_weights()
......@@ -70,20 +74,19 @@ def test_forward_multi_scale_deformable_attn_pytorch():
attention_weights.double()).detach()
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support')
def test_forward_equal_with_pytorch_double():
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
......@@ -93,8 +96,9 @@ def test_forward_equal_with_pytorch_double():
attention_weights.double()).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value.double(), shapes, level_start_index, sampling_locations.double(),
attention_weights.double(), im2col_step).detach().cpu()
value.cuda().double(), shapes.cuda(), level_start_index.cuda(),
sampling_locations.cuda().double(),
attention_weights.cuda().double(), im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
......@@ -103,20 +107,28 @@ def test_forward_equal_with_pytorch_double():
assert max_rel_err < 1e-15
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_forward_equal_with_pytorch_float():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_forward_equal_with_pytorch_float(device):
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)
torch.manual_seed(3)
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
......@@ -124,19 +136,37 @@ def test_forward_equal_with_pytorch_float():
output_pytorch = multi_scale_deformable_attn_pytorch(
value, shapes, sampling_locations, attention_weights).detach().cpu()
output_cuda = MultiScaleDeformableAttnFunction.apply(
value, shapes, level_start_index, sampling_locations,
attention_weights, im2col_step).detach().cpu()
assert torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_cuda - output_pytorch).abs().max()
max_rel_err = ((output_cuda - output_pytorch).abs() /
output_device = MultiScaleDeformableAttnFunction.apply(
value.to(device), shapes.to(device), level_start_index.to(device),
sampling_locations.to(device), attention_weights.to(device),
im2col_step).detach().cpu()
assert torch.allclose(output_device, output_pytorch, rtol=1e-2, atol=1e-3)
max_abs_err = (output_device - output_pytorch).abs().max()
max_rel_err = ((output_device - output_pytorch).abs() /
output_pytorch.abs()).max()
assert max_abs_err < 1e-9
assert max_rel_err < 1e-6
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
torch.half
])
@pytest.mark.parametrize('channels', [
4,
30,
......@@ -146,20 +176,22 @@ def test_forward_equal_with_pytorch_float():
1025,
])
def test_gradient_numerical(channels,
device,
dtype,
grad_value=True,
grad_sampling_loc=True,
grad_attn_weight=True):
N, M, _ = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).to(device)
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum((H * W).item() for H, W in shapes)
value = torch.rand(N, S, M, channels).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
value = torch.rand(N, S, M, channels).to(device) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).to(device)
attention_weights = torch.rand(N, Lq, M, L, P).to(device) + 1e-5
attention_weights /= attention_weights.sum(
-1, keepdim=True).sum(
-2, keepdim=True)
......@@ -170,13 +202,23 @@ def test_gradient_numerical(channels,
value.requires_grad = grad_value
sampling_locations.requires_grad = grad_sampling_loc
attention_weights.requires_grad = grad_attn_weight
if device == 'cuda':
dtype = torch.double
eps = 1e-6
elif device == 'mlu':
dtype = torch.float
eps = 1e-4
if _USING_PARROTS:
assert gradcheck(
func, (value.double(), shapes, level_start_index,
sampling_locations.double(), attention_weights.double(),
func, (value.to(dtype), shapes, level_start_index,
sampling_locations.to(dtype), attention_weights.to(dtype),
im2col_step),
no_grads=[shapes, level_start_index])
no_grads=[shapes, level_start_index],
eps=eps)
else:
assert gradcheck(func, (value.double(), shapes, level_start_index,
sampling_locations.double(),
attention_weights.double(), im2col_step))
assert gradcheck(
func, (value.to(dtype), shapes, level_start_index,
sampling_locations.to(dtype), attention_weights.to(dtype),
im2col_step),
eps=eps,
atol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment