"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from collections import OrderedDict
import os

import numpy as np

from oneflow.test_utils.automated_test_util import *
from oneflow.test_utils.test_util import GenArgList

import oneflow as flow
import oneflow.unittest

# Weight fixture for the basic conv2d case; the literal has shape (3, 1, 3, 3).
# NOTE(review): presumably (out_channels, in_channels, kernel_h, kernel_w) --
# confirm against the test that consumes it.
# fmt: off
test_conv2d_weight = np.array([
    [[  # [0, 0]
        [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
        [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
        [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
    ]],
    [[  # [1, 0]
        [0.29670074582099915, 1.3111951351165771, 0.5035904049873352],
        [-1.1894450187683105, -0.5502137541770935, -1.591875672340393],
        [-1.1081947088241577, 0.07872020453214645, -0.9185634255409241],
    ]],
    [[  # [2, 0]
        [-0.7457143664360046, -1.2080862522125244, 1.8140212297439575],
        [-1.5227429866790771, -2.515244960784912, -1.3549325466156006],
        [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],
    ]],
])
# fmt: on
# Input fixture for the basic conv2d case; the literal has shape (1, 1, 5, 5).
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_data = np.array([[[
    [  # row 0
        1.1630785465240479,
        0.4838046133518219,
        0.299563467502594,
        0.15302546322345734,
        -1.168814778327942,
    ],
    [  # row 1
        1.5580710172653198,
        -0.5459445714950562,
        -2.3556296825408936,
        0.5414402484893799,
        2.678506374359131,
    ],
    [  # row 2
        1.2546343803405762,
        -0.5487740635871887,
        -0.6810643672943115,
        -0.13531559705734253,
        0.37723132967948914,
    ],
    [  # row 3
        0.41016456484794617,
        0.5712682008743286,
        -2.757962703704834,
        1.0762799978256226,
        -0.6141325235366821,
    ],
    [  # row 4
        1.830764889717102,
        -1.1468064785003662,
        0.053837940096855164,
        -2.5074806213378906,
        -0.5916498899459839,
    ],
]]])
# fmt: on
# Reference input-gradient fixture for the basic conv2d case; the literal has
# shape (1, 1, 5, 5).
# NOTE(review): how the upstream gradient was chosen is not visible here --
# confirm against the consuming test.
# fmt: off
test_conv2d_data_grad = np.array([[[
    [  # row 0
        0.4095913469791412,
        0.2847584038972855,
        2.803684800863266,
        2.3940934538841248,
        2.5189263969659805,
    ],
    [  # row 1
        -1.9525419473648071,
        -4.606781497597694,
        -3.51521897315979,
        -1.562677025794983,
        1.0915625244379044,
    ],
    [  # row 2
        -2.1141327619552612,
        -6.987950943410397,
        -5.84306687861681,
        -3.7289341166615486,
        1.1448840647935867,
    ],
    [  # row 3
        -2.5237241089344025,
        -7.272709347307682,
        -8.646751679480076,
        -6.123027570545673,
        -1.3740423321723938,
    ],
    [  # row 4
        -0.1615908145904541,
        -2.381169445812702,
        -2.32784790545702,
        -2.1662570908665657,
        0.0533215403556824,
    ],
]]])
# fmt: on
# Reference weight-gradient fixture for the basic conv2d case; the literal has
# shape (3, 1, 3, 3). All three slices along axis 0 hold the same 3x3 values.
# fmt: off
test_conv2d_weight_grad = np.array([
    [[  # [0, 0]
        [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
        [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
        [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
    ]],
    [[  # [1, 0]
        [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
        [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
        [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
    ]],
    [[  # [2, 0]
        [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
        [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
        [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
    ]],
])
# fmt: on
# Reference output fixture for the basic conv2d case; the literal has shape
# (1, 3, 3, 3).
# NOTE(review): presumably (batch, out_channels, out_h, out_w) -- confirm
# against the consuming test.
# fmt: off
test_conv2d_output = np.array([[
    [  # [0, 0]
        [0.9699610471725464, -0.20758534967899323, 2.3857712745666504],
        [0.3666309118270874, 4.690882682800293, -8.203354835510254],
        [2.6072847843170166, -1.9033538103103638, 2.331153154373169],
    ],
    [  # [0, 1]
        [2.519343852996826, 2.3757898807525635, -1.6613528728485107],
        [0.5777544379234314, -3.5739502906799316, 5.349126815795898],
        [0.729295015335083, 1.5791023969650269, 3.7627718448638916],
    ],
    [  # [0, 2]
        [-0.27685487270355225, 6.446267127990723, -2.762883424758911],
        [-8.25644588470459, 9.616064071655273, 8.005367279052734],
        [-0.6944921016693115, 3.866114854812622, 4.788446426391602],
    ],
]])
# fmt: on
# Weight fixture for the conv2d-with-bias case; the literal has shape
# (3, 1, 3, 3).
# NOTE(review): presumably (out_channels, in_channels, kernel_h, kernel_w) --
# confirm against the consuming test.
# fmt: off
test_conv2d_with_bias_weight = np.array([
    [[  # [0, 0]
        [1.8271433115005493, -1.0446699857711792, 1.0062190294265747],
        [0.5174201130867004, -0.806931734085083, 1.3769007921218872],
        [0.205885112285614, 0.9943519234657288, -0.23580588400363922],
    ]],
    [[  # [1, 0]
        [0.29881811141967773, -1.9982075691223145, 0.3511354625225067],
        [-0.7644741535186768, 1.2594351768493652, -0.9629734754562378],
        [0.5080506205558777, 0.7561734318733215, 1.6839302778244019],
    ]],
    [[  # [2, 0]
        [1.2573646306991577, 0.13123232126235962, 1.6403018236160278],
        [-1.2138012647628784, 2.399970531463623, -0.38509097695350647],
        [-0.9878040552139282, 0.9585888385772705, -1.4976465702056885],
    ]],
])
# fmt: on
# Bias fixture for the conv2d-with-bias case; a 1-D literal with 3 entries.
# NOTE(review): presumably one bias per output channel -- confirm against the
# consuming test.
# fmt: off
test_conv2d_with_bias_bias = np.array([
    0.6605162620544434,
    -0.18903568387031555,
    -0.27302607893943787,
])
# fmt: on
# Input fixture for the conv2d-with-bias case; the literal has shape
# (1, 1, 5, 5).
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_with_bias_data = np.array([[[
    [  # row 0
        -0.47827261686325073,
        -1.1739492416381836,
        -0.7921845316886902,
        0.9321041703224182,
        -3.1557741165161133,
    ],
    [  # row 1
        2.1935296058654785,
        -0.5385921001434326,
        -0.8611332774162292,
        -1.881519079208374,
        -0.7205708026885986,
    ],
    [  # row 2
        -0.35601571202278137,
        -0.15963983535766602,
        1.797447681427002,
        0.19594945013523102,
        -1.7376397848129272,
    ],
    [  # row 3
        0.047347065061330795,
        0.14580930769443512,
        0.32604914903640747,
        0.4578782916069031,
        -0.8942581415176392,
    ],
    [  # row 4
        0.49383941292762756,
        -0.9043426513671875,
        -1.2140793800354004,
        2.1564064025878906,
        1.0938222408294678,
    ],
]]])
# fmt: on
# Reference output fixture for the conv2d-with-bias case; the literal has
# shape (1, 3, 3, 3).
# NOTE(review): presumably (batch, out_channels, out_h, out_w) -- confirm
# against the consuming test.
# fmt: off
test_conv2d_with_bias_output = np.array([[
    [  # [0, 0]
        [-0.05607491731643677, -0.185230553150177, -3.8808679580688477],
        [6.861937046051025, -2.3341472148895264, -0.5597308874130249],
        [1.8299254179000854, -2.770848274230957, 2.1958212852478027],
    ],
    [  # [0, 1]
        [2.9348952770233154, 4.117504119873047, -6.278541088104248],
        [0.2638452351093292, 3.998856782913208, 2.612290620803833],
        [-1.9891828298568726, -1.6476304531097412, 3.39066219329834],
    ],
    [  # [0, 2]
        [-8.44466781616211, 0.5747121572494507, -8.501373291015625],
        [-0.036642804741859436, -0.23458999395370483, -2.370849370956421],
        [2.8372013568878174, -2.987276077270508, 1.8382092714309692],
    ],
]])
# fmt: on
# Weight fixture for the grouped conv2d case; the literal has shape
# (2, 1, 3, 3).
# NOTE(review): presumably (out_channels, in_channels // groups, kernel_h,
# kernel_w) -- confirm against the consuming test.
# fmt: off
test_conv2d_group_weight = np.array([
    [[  # [0, 0]
        [-0.7248556613922119, 1.1119636297225952, -0.47827261686325073],
        [-1.1739492416381836, -0.7921845316886902, 0.9321041703224182],
        [-3.1557741165161133, 2.1935296058654785, -0.5385921001434326],
    ]],
    [[  # [1, 0]
        [-0.8611332774162292, -1.881519079208374, -0.7205708026885986],
        [-0.35601571202278137, -0.15963983535766602, 1.797447681427002],
        [0.19594945013523102, -1.7376397848129272, 0.047347065061330795],
    ]],
])
# fmt: on
# Reference input-gradient fixture for the grouped conv2d case; the literal
# has shape (1, 2, 5, 5).
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_group_data_grad = np.array([[
    [  # [0, 0]
        [  # row 0
            -0.7248556613922119,
            0.3871079683303833,
            -0.0911646485328674,
            0.6336910128593445,
            -0.4782726168632507,
        ],
        [  # row 1
            -1.8988049030303955,
            -1.5790258049964905,
            -1.125194251537323,
            0.7736106514930725,
            0.4538315534591675,
        ],
        [  # row 2
            -5.054579019546509,
            -2.5412703156471252,
            -2.6260308623313904,
            2.4285481572151184,
            -0.0847605466842651,
        ],
        [  # row 3
            -4.329723358154297,
            -2.9283782839775085,
            -2.534866213798523,
            1.794857144355774,
            0.3935120701789856,
        ],
        [  # row 4
            -3.1557741165161133,
            -0.9622445106506348,
            -1.5008366107940674,
            1.654937505722046,
            -0.5385921001434326,
        ],
    ],
    [  # [0, 1]
        [  # row 0
            -0.8611332774162292,
            -2.7426523566246033,
            -3.463223159313202,
            -2.6020898818969727,
            -0.7205708026885986,
        ],
        [  # row 1
            -1.2171489894390106,
            -3.2583079040050507,
            -2.1814310252666473,
            -0.9642820358276367,
            1.0768768787384033,
        ],
        [  # row 2
            -1.0211995393037796,
            -4.799998238682747,
            -3.6757742948830128,
            -2.654574755579233,
            1.1242239437997341,
        ],
        [  # row 3
            -0.1600662618875504,
            -2.0573458820581436,
            -0.2125511355698109,
            -0.0524848736822605,
            1.8447947464883327,
        ],
        [  # row 4
            0.195949450135231,
            -1.5416903346776962,
            -1.4943432696163654,
            -1.6902927197515965,
            0.0473470650613308,
        ],
    ],
]])
# fmt: on
# Reference weight-gradient fixture for the grouped conv2d case; the literal
# has shape (2, 1, 3, 3).
# fmt: off
test_conv2d_group_weight_grad = np.array([
    [[  # [0, 0]
        [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
        [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
        [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
    ]],
    [[  # [1, 0]
        [3.30740749835968, -0.7220746576786041, -3.660933956503868],
        [0.5273916646838188, -2.631059892475605, -7.6207195818424225],
        [-3.5466641262173653, -8.214546449482441, -11.031560003757477],
    ]],
])
# fmt: on
# Input fixture for the grouped conv2d case; the literal has shape
# (1, 2, 5, 5).
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_group_data = np.array([[
    [  # [0, 0]
        [  # row 0
            1.1630785465240479,
            0.4838046133518219,
            0.299563467502594,
            0.15302546322345734,
            -1.168814778327942,
        ],
        [  # row 1
            1.5580710172653198,
            -0.5459445714950562,
            -2.3556296825408936,
            0.5414402484893799,
            2.678506374359131,
        ],
        [  # row 2
            1.2546343803405762,
            -0.5487740635871887,
            -0.6810643672943115,
            -0.13531559705734253,
            0.37723132967948914,
        ],
        [  # row 3
            0.41016456484794617,
            0.5712682008743286,
            -2.757962703704834,
            1.0762799978256226,
            -0.6141325235366821,
        ],
        [  # row 4
            1.830764889717102,
            -1.1468064785003662,
            0.053837940096855164,
            -2.5074806213378906,
            -0.5916498899459839,
        ],
    ],
    [  # [0, 1]
        [  # row 0
            0.8586049675941467,
            -0.2279418259859085,
            0.2013147622346878,
            0.35005471110343933,
            0.5360521078109741,
        ],
        [  # row 1
            1.5194443464279175,
            1.9040879011154175,
            -1.5734431743621826,
            -0.14007866382598877,
            0.29670074582099915,
        ],
        [  # row 2
            1.3111951351165771,
            0.5035904049873352,
            -1.1894450187683105,
            -0.5502137541770935,
            -1.591875672340393,
        ],
        [  # row 3
            -1.1081947088241577,
            0.07872020453214645,
            -0.9185634255409241,
            -0.7457143664360046,
            -1.2080862522125244,
        ],
        [  # row 4
            1.8140212297439575,
            -1.5227429866790771,
            -2.515244960784912,
            -1.3549325466156006,
            -0.9574840068817139,
        ],
    ],
]])
# fmt: on
# Reference output fixture for the grouped conv2d case; the literal has shape
# (1, 2, 3, 3).
# NOTE(review): presumably (batch, out_channels, out_h, out_w) -- confirm
# against the consuming test.
# fmt: off
test_conv2d_group_output = np.array([[
    [  # [0, 0]
        [-8.836943626403809, 3.2316627502441406, 6.994439601898193],
        [-0.8386597037315369, -9.857108116149902, 13.68197250366211],
        [-13.020713806152344, 7.310227870941162, -3.3760271072387695],
    ],
    [  # [0, 1]
        [-4.803101539611816, 1.026240587234497, 0.5452112555503845],
        [-6.839838027954102, 2.0195930004119873, 0.11328654736280441],
        [0.393694669008255, 4.987061023712158, 3.297354221343994],
    ],
]])
# fmt: on
# Weight fixture for the padded conv2d case; the literal has shape
# (1, 1, 3, 3).
# NOTE(review): presumably (out_channels, in_channels, kernel_h, kernel_w) --
# confirm against the consuming test.
# fmt: off
test_conv2d_padding_weight = np.array([[[
    [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
    [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
    [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
]]])
# fmt: on
# Input fixture for the padded conv2d case; the literal has shape
# (1, 1, 5, 5) and holds the same values as test_conv2d_data.
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_padding_data = np.array([[[
    [  # row 0
        1.1630785465240479,
        0.4838046133518219,
        0.299563467502594,
        0.15302546322345734,
        -1.168814778327942,
    ],
    [  # row 1
        1.5580710172653198,
        -0.5459445714950562,
        -2.3556296825408936,
        0.5414402484893799,
        2.678506374359131,
    ],
    [  # row 2
        1.2546343803405762,
        -0.5487740635871887,
        -0.6810643672943115,
        -0.13531559705734253,
        0.37723132967948914,
    ],
    [  # row 3
        0.41016456484794617,
        0.5712682008743286,
        -2.757962703704834,
        1.0762799978256226,
        -0.6141325235366821,
    ],
    [  # row 4
        1.830764889717102,
        -1.1468064785003662,
        0.053837940096855164,
        -2.5074806213378906,
        -0.5916498899459839,
    ],
]]])
# fmt: on
# Reference input-gradient fixture for the padded conv2d case; the literal has
# shape (1, 1, 5, 5). Every row repeats a single value: rows 1-3 share one
# value, while rows 0 and 4 each have their own.
# fmt: off
test_conv2d_padding_data_grad = np.array([[[
    [  # row 0
        3.237529069185257,
        3.237529069185257,
        3.237529069185257,
        3.237529069185257,
        3.237529069185257,
    ],
    [  # row 1
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
    ],
    [  # row 2
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
    ],
    [  # row 3
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
        3.428095132112503,
    ],
    [  # row 4
        2.596117228269577,
        2.596117228269577,
        2.596117228269577,
        2.596117228269577,
        2.596117228269577,
    ],
]]])
# fmt: on
# Reference weight-gradient fixture for the padded conv2d case; the literal
# has shape (1, 1, 3, 3). Each row repeats a single value.
# fmt: off
test_conv2d_padding_weight_grad = np.array([[[
    [1.7594299167394638, 1.7594299167394638, 1.7594299167394638],
    [-0.6019042432308197, -0.6019042432308197, -0.6019042432308197],
    [-1.532561555504799, -1.532561555504799, -1.532561555504799],
]]])
# fmt: on
# Reference output fixture for the padded conv2d case; the literal has shape
# (1, 1, 5, 7).
# NOTE(review): the padding configuration that produces this 5x7 map is not
# visible here -- confirm against the consuming test.
# fmt: off
test_conv2d_padding_output = np.array([[[
    [  # row 0
        1.5489805936813354,
        -1.0164761543273926,
        5.277345657348633,
        3.153532028198242,
        -7.301508903503418,
        -3.7565059661865234,
        4.690962314605713,
    ],
    [  # row 1
        2.425799608230591,
        -2.0592665672302246,
        0.9699610471725464,
        -0.20758534967899323,
        2.3857712745666504,
        1.1719579696655273,
        0.6523551940917969,
    ],
    [  # row 2
        2.1625545024871826,
        -1.3517316579818726,
        0.3666309118270874,
        4.690882682800293,
        -8.203354835510254,
        3.0248217582702637,
        1.2624683380126953,
    ],
    [  # row 3
        0.6193475723266602,
        -2.0285415649414062,
        2.6072847843170166,
        -1.9033538103103638,
        2.331153154373169,
        -3.998155355453491,
        -1.0176407098770142,
    ],
    [  # row 4
        2.8643176555633545,
        -0.7396122217178345,
        -0.2253415733575821,
        -2.846742630004883,
        -4.961236476898193,
        -0.1308247298002243,
        -0.7344070672988892,
    ],
]]])
# fmt: on
# Weight fixture for the strided conv2d case; the literal has shape
# (1, 1, 3, 3).
# NOTE(review): presumably (out_channels, in_channels, kernel_h, kernel_w) --
# confirm against the consuming test.
# fmt: off
test_conv2d_stride_weight = np.array([[[
    [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
    [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
    [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
]]])
# fmt: on
# Input fixture for the strided conv2d case; the literal has shape
# (1, 1, 5, 5) and holds the same values as test_conv2d_data.
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_stride_data = np.array([[[
    [  # row 0
        1.1630785465240479,
        0.4838046133518219,
        0.299563467502594,
        0.15302546322345734,
        -1.168814778327942,
    ],
    [  # row 1
        1.5580710172653198,
        -0.5459445714950562,
        -2.3556296825408936,
        0.5414402484893799,
        2.678506374359131,
    ],
    [  # row 2
        1.2546343803405762,
        -0.5487740635871887,
        -0.6810643672943115,
        -0.13531559705734253,
        0.37723132967948914,
    ],
    [  # row 3
        0.41016456484794617,
        0.5712682008743286,
        -2.757962703704834,
        1.0762799978256226,
        -0.6141325235366821,
    ],
    [  # row 4
        1.830764889717102,
        -1.1468064785003662,
        0.053837940096855164,
        -2.5074806213378906,
        -0.5916498899459839,
    ],
]]])
# fmt: on
# Reference input-gradient fixture for the strided conv2d case; the literal
# has shape (1, 1, 5, 5). Rows 0, 2 and 4 are identical, as are rows 1 and 3.
# fmt: off
test_conv2d_stride_data_grad = np.array([[[
    [  # row 0
        0.5360521078109741,
        1.5194443464279175,
        0.3500547111034393,
        0.5360521078109741,
        1.5194443464279175,
    ],
    [  # row 1
        -1.8013850003480911,
        0.061236098408699,
        2.762692868709564,
        -1.8013850003480911,
        0.061236098408699,
    ],
    [  # row 2
        0.5360521078109741,
        1.5194443464279175,
        0.3500547111034393,
        0.5360521078109741,
        1.5194443464279175,
    ],
    [  # row 3
        -1.8013850003480911,
        0.061236098408699,
        2.762692868709564,
        -1.8013850003480911,
        0.061236098408699,
    ],
    [  # row 4
        0.5360521078109741,
        1.5194443464279175,
        0.3500547111034393,
        0.5360521078109741,
        1.5194443464279175,
    ],
]]])
# fmt: on
# Reference weight-gradient fixture for the strided conv2d case; the literal
# has shape (1, 1, 3, 3). Rows 0 and 2 are identical.
# fmt: off
test_conv2d_stride_weight_grad = np.array([[[
    [-5.1135923862457275, 3.5859558284282684, 2.089697480201721],
    [-0.3276629596948624, 1.7587070614099503, -2.5950092673301697],
    [-5.1135923862457275, 3.5859558284282684, 2.089697480201721],
]]])
# fmt: on
# Reference output fixture for the strided conv2d case; the literal has shape
# (1, 1, 3, 2).
# NOTE(review): the stride/padding configuration that yields a 3x2 map is not
# visible here -- confirm against the consuming test.
# fmt: off
test_conv2d_stride_output = np.array([[[
    [-1.0164761543273926, -7.301508903503418],
    [-1.3517316579818726, -8.203354835510254],
    [-0.7396122217178345, -4.961236476898193],
]]])
# fmt: on
# Weight fixture for the non-square-kernel conv2d case; the literal has shape
# (1, 1, 3, 5).
# NOTE(review): presumably (out_channels, in_channels, kernel_h, kernel_w) --
# confirm against the consuming test.
# fmt: off
test_conv2d_kernel_weight = np.array([[[
    [  # row 0
        -0.9574840068817139,
        -0.7248556613922119,
        1.1119636297225952,
        -0.47827261686325073,
        -1.1739492416381836,
    ],
    [  # row 1
        -0.7921845316886902,
        0.9321041703224182,
        -3.1557741165161133,
        2.1935296058654785,
        -0.5385921001434326,
    ],
    [  # row 2
        -0.8611332774162292,
        -1.881519079208374,
        -0.7205708026885986,
        -0.35601571202278137,
        -0.15963983535766602,
    ],
]]])
# fmt: on
# Input fixture for the non-square-kernel conv2d case; the literal has shape
# (1, 1, 7, 7).
# NOTE(review): presumably (batch, channels, height, width) -- confirm against
# the consuming test.
# fmt: off
test_conv2d_kernel_data = np.array([[[
    [  # row 0
        1.1630785465240479,
        0.4838046133518219,
        0.299563467502594,
        0.15302546322345734,
        -1.168814778327942,
        1.5580710172653198,
        -0.5459445714950562,
    ],
    [  # row 1
        -2.3556296825408936,
        0.5414402484893799,
        2.678506374359131,
        1.2546343803405762,
        -0.5487740635871887,
        -0.6810643672943115,
        -0.13531559705734253,
    ],
    [  # row 2
        0.37723132967948914,
        0.41016456484794617,
        0.5712682008743286,
        -2.757962703704834,
        1.0762799978256226,
        -0.6141325235366821,
        1.830764889717102,
    ],
    [  # row 3
        -1.1468064785003662,
        0.053837940096855164,
        -2.5074806213378906,
        -0.5916498899459839,
        0.8586049675941467,
        -0.2279418259859085,
        0.2013147622346878,
    ],
    [  # row 4
        0.35005471110343933,
        0.5360521078109741,
        1.5194443464279175,
        1.9040879011154175,
        -1.5734431743621826,
        -0.14007866382598877,
        0.29670074582099915,
    ],
    [  # row 5
        1.3111951351165771,
        0.5035904049873352,
        -1.1894450187683105,
        -0.5502137541770935,
        -1.591875672340393,
        -1.1081947088241577,
        0.07872020453214645,
    ],
    [  # row 6
        -0.9185634255409241,
        -0.7457143664360046,
        -1.2080862522125244,
        1.8140212297439575,
        -1.5227429866790771,
        -2.515244960784912,
        -1.3549325466156006,
    ],
]]])
# fmt: on
# Reference input-gradient fixture for the non-square-kernel conv2d case; the
# literal has shape (1, 1, 7, 7). Rows 2, 3 and 4 are identical.
# fmt: off
test_conv2d_kernel_data_grad = np.array([[[
    [  # row 0
        -0.9574840068817139,
        -1.6823396682739258,
        -0.5703760385513306,
        -0.0911646485328674,
        -0.5402582287788391,
        -1.6522218585014343,
        -1.1739492416381836,
    ],
    [  # row 1
        -1.749668538570404,
        -1.5424200296401978,
        -3.586230516433716,
        -0.121304988861084,
        -2.0410948395729065,
        0.0027156472206116,
        -1.7125413417816162,
    ],
    [  # row 2
        -2.6108018159866333,
        -4.285072386264801,
        -7.049453675746918,
        -3.079410582780838,
        -3.2773211896419525,
        -0.5129399001598358,
        -1.8721811771392822,
    ],
    [  # row 3
        -2.6108018159866333,
        -4.285072386264801,
        -7.049453675746918,
        -3.079410582780838,
        -3.2773211896419525,
        -0.5129399001598358,
        -1.8721811771392822,
    ],
    [  # row 4
        -2.6108018159866333,
        -4.285072386264801,
        -7.049453675746918,
        -3.079410582780838,
        -3.2773211896419525,
        -0.5129399001598358,
        -1.8721811771392822,
    ],
    [  # row 5
        -1.6533178091049194,
        -2.6027327179908752,
        -6.479077637195587,
        -2.9882459342479706,
        -2.7370629608631134,
        1.1392819583415985,
        -0.6982319355010986,
    ],
    [  # row 6
        -0.8611332774162292,
        -2.7426523566246033,
        -3.463223159313202,
        -2.958105593919754,
        -1.236226350069046,
        -0.5156555473804474,
        -0.159639835357666,
    ],
]]])
# fmt: on
# Reference weight-gradient fixture for the non-square-kernel conv2d case; the
# literal has shape (1, 1, 3, 5).
# fmt: off
test_conv2d_kernel_weight_grad = np.array([[[
    [  # row 0
        2.974529668688774,
        4.548736393451691,
        1.1672898679971695,
        -1.499158263206482,
        0.1862268149852753,
    ],
    [  # row 1
        1.6534235626459122,
        2.3762744814157486,
        -1.448018729686737,
        -5.2917241007089615,
        -2.278435029089451,
    ],
    [  # row 2
        -2.083257421851158,
        -2.23808591067791,
        -5.749193429946899,
        -7.540486767888069,
        -6.306201495230198,
    ],
]]])
# fmt: on
# Reference output fixture for the non-square-kernel conv2d case; the literal
# has shape (1, 1, 5, 3).
# NOTE(review): presumably (batch, out_channels, out_h, out_w) -- confirm
# against the consuming test.
# fmt: off
test_conv2d_kernel_output = np.array([[[
    [-3.5647754669189453, -4.234736919403076, 1.4046944379806519],
    [-0.6964312791824341, 16.42838478088379, -9.649789810180664],
    [4.312150478363037, -6.283960819244385, -4.8443922996521],
    [-2.772286891937256, -4.483709812164307, 12.315184593200684],
    [7.39893913269043, 1.305102825164795, -2.049992561340332],
]]])
# fmt: on
# Weight fixture for the dilated conv2d case; the literal has shape
# (1, 1, 3, 3).
# NOTE(review): presumably (out_channels, in_channels, kernel_h, kernel_w) --
# confirm against the consuming test.
# fmt: off
test_conv2d_dilation_weight = np.array([[[
    [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],
    [-0.47827261686325073, -1.1739492416381836, -0.7921845316886902],
    [0.9321041703224182, -3.1557741165161133, 2.1935296058654785],
]]])
# fmt: on
# Input tensor, shape (1, 1, 7, 7), fed to the dilated Conv2d by
# ``test_conv2d_dilation`` and ``test_conv2d_dilation_backward``.
test_conv2d_dilation_data = np.array(
    [
        [
            [
                [
                    1.1630785465240479,
                    0.4838046133518219,
                    0.299563467502594,
                    0.15302546322345734,
                    -1.168814778327942,
                    1.5580710172653198,
                    -0.5459445714950562,
                ],
                [
                    -2.3556296825408936,
                    0.5414402484893799,
                    2.678506374359131,
                    1.2546343803405762,
                    -0.5487740635871887,
                    -0.6810643672943115,
                    -0.13531559705734253,
                ],
                [
                    0.37723132967948914,
                    0.41016456484794617,
                    0.5712682008743286,
                    -2.757962703704834,
                    1.0762799978256226,
                    -0.6141325235366821,
                    1.830764889717102,
                ],
                [
                    -1.1468064785003662,
                    0.053837940096855164,
                    -2.5074806213378906,
                    -0.5916498899459839,
                    0.8586049675941467,
                    -0.2279418259859085,
                    0.2013147622346878,
                ],
                [
                    0.35005471110343933,
                    0.5360521078109741,
                    1.5194443464279175,
                    1.9040879011154175,
                    -1.5734431743621826,
                    -0.14007866382598877,
                    0.29670074582099915,
                ],
                [
                    1.3111951351165771,
                    0.5035904049873352,
                    -1.1894450187683105,
                    -0.5502137541770935,
                    -1.591875672340393,
                    -1.1081947088241577,
                    0.07872020453214645,
                ],
                [
                    -0.9185634255409241,
                    -0.7457143664360046,
                    -1.2080862522125244,
                    1.8140212297439575,
                    -1.5227429866790771,
                    -2.515244960784912,
                    -1.3549325466156006,
                ],
            ]
        ]
    ]
)
# Reference gradient with respect to the input, shape (1, 1, 7, 7), compared
# with ``x.grad`` by ``test_conv2d_dilation_backward`` after back-propagating the
# sum of the output. Only columns 0, 3 and 6 hold non-zero entries.
test_conv2d_dilation_data_grad = np.array(
    [
        [
            [
                [
                    -0.9574840068817139,
                    0.0,
                    0.0,
                    -0.7248556613922119,
                    0.0,
                    0.0,
                    1.1119636297225952,
                ],
                [
                    -0.9574840068817139,
                    0.0,
                    0.0,
                    -0.7248556613922119,
                    0.0,
                    0.0,
                    1.1119636297225952,
                ],
                [
                    -1.4357566237449646,
                    0.0,
                    0.0,
                    -1.8988049030303955,
                    0.0,
                    0.0,
                    0.319779098033905,
                ],
                [
                    -0.4782726168632507,
                    0.0,
                    0.0,
                    -1.1739492416381836,
                    0.0,
                    0.0,
                    -0.7921845316886902,
                ],
                [
                    0.4538315534591675,
                    0.0,
                    0.0,
                    -4.329723358154297,
                    0.0,
                    0.0,
                    1.4013450741767883,
                ],
                [
                    0.9321041703224182,
                    0.0,
                    0.0,
                    -3.1557741165161133,
                    0.0,
                    0.0,
                    2.1935296058654785,
                ],
                [
                    0.9321041703224182,
                    0.0,
                    0.0,
                    -3.1557741165161133,
                    0.0,
                    0.0,
                    2.1935296058654785,
                ],
            ]
        ]
    ]
)
# Reference weight gradient, shape (1, 1, 3, 3), compared with
# ``conv.weight.grad`` by ``test_conv2d_dilation_backward``.
test_conv2d_dilation_weight_grad = np.array(
    [
        [
            [
                [-0.8153198063373566, -1.3503028601408005, 1.1495047211647034],
                [-0.4195204377174377, -1.4455246925354004, 2.328780397772789],
                [0.7426864206790924, 3.1678953766822815, -0.979511596262455],
            ]
        ]
    ]
)
# Reference forward output, shape (1, 1, 3, 1), checked by
# ``test_conv2d_dilation``.
test_conv2d_dilation_output = np.array(
    [[[[-5.2563982009887695], [5.410353183746338], [-8.517012596130371]]]]
)


def _test_conv2d(
    test_case,
    conv,
    data,
    weight,
    output,
    bias=None,
    device="cuda",
):
    """Run ``conv`` forward on ``data`` and compare the result with ``output``.

    The module's weight (and its bias, when ``bias`` is given) is overwritten
    with the supplied values before the module is moved to ``device``. The
    comparison uses ``np.allclose`` with ``rtol=1e-4`` and ``atol=1e-8``.
    """
    target = flow.device(device)
    inputs = flow.tensor(data, dtype=flow.float32, device=target)

    # Install the reference parameters first, then place the module on the
    # device the input lives on.
    conv.weight = flow.nn.Parameter(flow.Tensor(weight))
    if bias is not None:
        conv.bias = flow.nn.Parameter(flow.Tensor(bias))
    conv.to(target)

    result = conv(inputs).numpy()
    matches = np.allclose(result, output, rtol=1e-4, atol=1e-8)
    test_case.assertTrue(matches)


def _test_conv2d_backward(
    test_case,
    conv,
    data,
    weight,
    data_grad,
    weight_grad,
    bias=None,
    device="cuda",
):
    """Check the input and weight gradients of ``conv`` against references.

    The module's weight (and its bias, when ``bias`` is given) is overwritten
    with the supplied values, the module is moved to ``device``, and the sum
    of ``conv(data)`` is back-propagated. ``data_grad`` is compared with the
    gradient of the input and ``weight_grad`` with the gradient of
    ``conv.weight``, both via ``np.allclose`` (``rtol=1e-4``, ``atol=1e-8``).
    """
    target = flow.device(device)
    inputs = flow.tensor(data, dtype=flow.float32, device=target, requires_grad=True)
    conv.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True)
    if bias is not None:
        conv.bias = flow.nn.Parameter(flow.Tensor(bias))
    conv.to(target)

    loss = conv(inputs).sum()
    loss.backward()

    input_grad_ok = np.allclose(
        inputs.grad.numpy(), data_grad, rtol=1e-4, atol=1e-8
    )
    test_case.assertTrue(input_grad_ok)
    weight_grad_ok = np.allclose(
        conv.weight.grad.numpy(), weight_grad, rtol=1e-4, atol=1e-8
    )
    test_case.assertTrue(weight_grad_ok)


def _test_conv2d_large_in_channel(test_case, device):
    """Grouped Conv2d with more input than output channels (4 -> 2, groups=2).

    Runs a bias-free 3x3 convolution on a fixed (1, 4, 4, 4) input with fixed
    weights, compares the output with reference values, then back-propagates
    the summed output and compares the input gradient with reference values.
    Both comparisons use ``np.allclose`` with ``rtol=1e-3`` and ``atol=1e-3``.
    """
    x_np = np.array(
        [
            [
                [
                    [
                        0.6206631238581714,
                        -1.1225329393404626,
                        0.8407155480700242,
                        -0.6845162855236345,
                    ],
                    [
                        -0.5186484633906412,
                        0.10420735184519186,
                        -0.1711568947473012,
                        0.5168640476046483,
                    ],
                    [
                        -0.12429464919764661,
                        0.050277779246134253,
                        -1.0144501797426606,
                        -2.184600444658526,
                    ],
                    [
                        0.28918126931309923,
                        -0.822872663244595,
                        0.44019150436683663,
                        -1.0247720130825562,
                    ],
                ],
                [
                    [
                        0.7786504412818226,
                        -0.7501839068078657,
                        -0.8187283189941765,
                        -1.1116653569170698,
                    ],
                    [
                        0.18085524152316743,
                        -1.3461349607476678,
                        1.142505437476448,
                        -0.000649619704040145,
                    ],
                    [
                        0.03160672782674317,
                        -0.006318157449953413,
                        1.2218487782604377,
                        0.15903027907930234,
                    ],
                    [
                        1.5857011815642381,
                        0.6656477116332891,
                        -0.04036621813223574,
                        -0.3427168687988546,
                    ],
                ],
                [
                    [
                        -1.1774346070102524,
                        1.6195241269303395,
                        -0.36185552303441965,
                        -1.1382193113192487,
                    ],
                    [
                        0.08061907334568702,
                        1.5025447613238763,
                        -1.1591348706634745,
                        1.6449050139676873,
                    ],
                    [
                        1.1539915649822392,
                        -2.414624939646017,
                        0.3056063774849572,
                        1.1920089257083162,
                    ],
                    [
                        0.7623012858982319,
                        -0.01685314742940813,
                        -1.096666898224702,
                        -0.4406476137098582,
                    ],
                ],
                [
                    [
                        0.9383797282214235,
                        -1.1075876842796508,
                        -0.4420913825139058,
                        -1.0736097610655628,
                    ],
                    [
                        -0.3101376466546291,
                        1.6578227745160954,
                        -0.6225454278031398,
                        0.6831188620748697,
                    ],
                    [
                        0.00743800968372913,
                        -0.8089158949698473,
                        2.08084287836801,
                        0.721204366332351,
                    ],
                    [
                        0.5694701823297723,
                        0.031519314469744895,
                        -0.5041680957766629,
                        -0.4738588233094669,
                    ],
                ],
            ]
        ]
    )
    x = flow.tensor(
        x_np,
        dtype=flow.float32,
        device=flow.device(device),
        requires_grad=True,
    )
    weight_np = np.array(
        [
            [
                [
                    [0.06456436216831207, -0.10852358490228653, -0.21638715267181396],
                    [-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],
                    [0.05026858672499657, 0.10818571597337723, 0.02056501805782318],
                ],
                [
                    [0.205095112323761, 0.1488947868347168, -0.2344113141298294],
                    [0.1684819906949997, -0.21986986696720123, 0.1082606166601181],
                    [-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],
                ],
            ],
            [
                [
                    [-0.09441672265529633, -0.03644559532403946, -0.22235223650932312],
                    [-0.1771145612001419, 0.08043312281370163, 0.06938580423593521],
                    [0.054393064230680466, -0.05483492836356163, 0.23438701033592224],
                ],
                [
                    [0.22666795551776886, 0.0874653309583664, 0.07092718034982681],
                    [0.08883464336395264, -0.052362944930791855, -0.1720171570777893],
                    [0.10441060364246368, 0.011952142231166363, -0.0894528403878212],
                ],
            ],
        ]
    )
    conv = flow.nn.Conv2d(4, 2, 3, groups=2, bias=False)
    conv.weight = flow.nn.Parameter(flow.Tensor(weight_np), requires_grad=True)
    conv = conv.to(device)
    y = conv(x)
    expected_out = [
        [
            [
                [0.7666134238243103, -0.3961866497993469],
                [-0.656266987323761, -1.1613956689834595],
            ],
            [
                [0.3077264130115509, -0.42817503213882446],
                [-0.5761325359344482, 0.1300736665725708],
            ],
        ]
    ]
    forward_ok = np.allclose(y.numpy(), expected_out, rtol=1e-3, atol=1e-3)
    test_case.assertTrue(forward_ok)

    # Back-propagate the scalar sum of the output to obtain the input gradient.
    y.sum().backward()
    expected_grad = [
        [
            [
                [
                    0.06456436216831207,
                    -0.04395922273397446,
                    -0.3249107301235199,
                    -0.21638715267181396,
                ],
                [
                    -0.16334669291973114,
                    -0.12419328093528748,
                    0.017341122031211853,
                    -0.021812304854393005,
                ],
                [
                    -0.17764246463775635,
                    0.07822024822235107,
                    0.47100257873535156,
                    0.21513986587524414,
                ],
                [
                    0.05026858672499657,
                    0.1584542989730835,
                    0.128750741481781,
                    0.02056501805782318,
                ],
            ],
            [
                [
                    0.205095112323761,
                    0.3539898991584778,
                    -0.08551652729511261,
                    -0.2344113141298294,
                ],
                [
                    0.3735771179199219,
                    0.30260205268859863,
                    -0.19712577760219574,
                    -0.1261506974697113,
                ],
                [
                    0.015584588050842285,
                    -0.03308109939098358,
                    0.07913993299007416,
                    0.12780562043190002,
                ],
                [
                    -0.1528974026441574,
                    0.018306776881217957,
                    0.1907491832971573,
                    0.01954500749707222,
                ],
            ],
            [
                [
                    -0.09441672265529633,
                    -0.13086232542991638,
                    -0.258797824382782,
                    -0.22235223650932312,
                ],
                [
                    -0.27153128385543823,
                    -0.22754377126693726,
                    -0.10897888988256454,
                    -0.1529664397239685,
                ],
                [
                    -0.12272149324417114,
                    -0.09712330251932144,
                    0.32937100529670715,
                    0.30377280712127686,
                ],
                [
                    0.054393064230680466,
                    -0.00044186413288116455,
                    0.1795520782470703,
                    0.23438701033592224,
                ],
            ],
            [
                [
                    0.22666795551776886,
                    0.31413328647613525,
                    0.1583925187587738,
                    0.07092718034982681,
                ],
                [
                    0.3155025839805603,
                    0.35060498118400574,
                    -0.06598758697509766,
                    -0.1010899767279625,
                ],
                [
                    0.19324524700641632,
                    0.1528344452381134,
                    -0.301880806684494,
                    -0.2614699900150299,
                ],
                [
                    0.10441060364246368,
                    0.11636274307966232,
                    -0.07750070095062256,
                    -0.0894528403878212,
                ],
            ],
        ]
    ]
    grad_ok = np.allclose(x.grad.numpy(), expected_grad, rtol=1e-3, atol=1e-3)
    test_case.assertTrue(grad_ok)


def _test_conv2d_large_out_channel(test_case, device):
    """Grouped Conv2d with more output than input channels (2 -> 4, groups=2).

    Runs a bias-free 3x3 convolution on a fixed (1, 2, 5, 5) input with fixed
    weights, compares the output with reference values, then back-propagates
    the summed output and compares the input gradient with reference values.
    Both comparisons use ``np.allclose`` with ``rtol=1e-3`` and ``atol=1e-3``.
    """
    x_np = np.array(
        [
            [
                [
                    [0.56573248, -0.19689320, -0.67875558, 0.34328273, 0.31964567],
                    [-1.33715475, 0.33422229, -1.27643383, 0.37904647, 0.35891593],
                    [0.84579802, 2.12729621, -0.51423287, 0.61297560, -1.31156564],
                    [-0.71047139, 1.02679253, -0.76686019, -0.72969633, 0.73425150],
                    [-0.13592879, -1.03207183, -0.22554775, 0.74148071, 0.96601510],
                ],
                [
                    [0.51595992, 0.49624804, 0.91145641, 0.49247262, 0.41002217],
                    [-1.08001196, 1.55497086, -0.81963140, -0.45511565, -0.60269165],
                    [0.05563145, -0.94318372, -1.17058158, -0.73568577, 0.57810956],
                    [-0.40260276, -0.10309298, 1.12378800, -0.23510537, -0.73893374],
                    [-0.52712536, -0.00717016, -1.85051966, -1.50790560, 1.38335907],
                ],
            ]
        ]
    )
    x = flow.tensor(
        x_np,
        dtype=flow.float32,
        device=flow.device(device),
        requires_grad=True,
    )
    weight_np = np.array(
        [
            [
                [
                    [-0.19489679, -0.32377058, 0.21736273],
                    [0.04095296, -0.21552679, -0.14626531],
                    [-0.19359522, -0.00742865, -0.19832158],
                ]
            ],
            [
                [
                    [0.29926914, 0.00931164, 0.26197660],
                    [0.27611443, -0.15439281, -0.19027126],
                    [-0.28909120, 0.30367029, -0.05168664],
                ]
            ],
            [
                [
                    [-0.03155736, 0.17610769, 0.22111714],
                    [0.22790670, -0.32897446, -0.03260243],
                    [-0.10274851, -0.06903386, -0.19438276],
                ]
            ],
            [
                [
                    [-0.24573688, -0.06723209, -0.21363299],
                    [-0.02136187, -0.24994437, -0.18691199],
                    [0.12189507, 0.29469389, 0.03398871],
                ]
            ],
        ]
    )
    conv = flow.nn.Conv2d(2, 4, 3, groups=2, bias=False)
    conv.weight = flow.nn.Parameter(flow.Tensor(weight_np), requires_grad=True)
    conv = conv.to(device)
    y = conv(x)
    expected_out = np.array(
        [
            [
                [
                    [-0.21170563, 0.03652292, 0.25926736],
                    [-0.19168918, 0.49044561, 0.25099146],
                    [-1.02489340, 0.25361472, -0.51828313],
                ],
                [
                    [0.23977707, -0.56090075, -0.19285655],
                    [-0.17167747, 0.24558367, -0.30935860],
                    [-0.33303234, 1.52472734, -0.49013454],
                ],
                [
                    [-0.17137986, 1.21333742, 0.18988736],
                    [0.31785482, -0.12121570, -0.18676008],
                    [-0.10680684, -0.30298883, 0.41809759],
                ],
                [
                    [-0.87821335, -0.51665992, -0.44061098],
                    [0.74804580, 0.53107250, 0.50418228],
                    [-0.00512899, -0.36455840, -0.23643512],
                ],
            ]
        ]
    )
    forward_ok = np.allclose(y.numpy(), expected_out, rtol=1e-3, atol=1e-3)
    test_case.assertTrue(forward_ok)

    # Back-propagate the scalar sum of the output to obtain the input gradient.
    y.sum().backward()
    expected_grad = np.array(
        [
            [
                [
                    [0.10437235, -0.21008658, 0.26925275, 0.16488039, 0.47933933],
                    [0.42143974, -0.26293880, -0.12013602, -0.54157579, 0.14280275],
                    [-0.06124666, -0.44938356, -0.55658901, -0.49534237, -0.10720548],
                    [-0.16561902, -0.23929697, -0.82584178, -0.66022277, -0.58654481],
                    [-0.48268640, -0.18644476, -0.43645298, 0.04623342, -0.25000823],
                ],
                [
                    [-0.27729425, -0.16841865, -0.16093449, 0.11635975, 0.00748415],
                    [-0.07074942, -0.54079264, -0.75282294, -0.68207347, -0.21203026],
                    [-0.05160286, -0.29598606, -0.66841042, -0.61680746, -0.37242430],
                    [0.22569139, -0.12756741, -0.50747585, -0.73316729, -0.37990844],
                    [0.01914656, 0.24480659, 0.08441254, 0.06526598, -0.16039404],
                ],
            ]
        ]
    )
    grad_ok = np.allclose(x.grad.numpy(), expected_grad, rtol=1e-3, atol=1e-3)
    test_case.assertTrue(grad_ok)


@flow.unittest.skip_unless_1n1d()
class TestConv2d(flow.unittest.TestCase):
    def test_conv2d_default_init(test_case):
        """A freshly constructed Conv2d must not have all-zero weight or bias."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            target = flow.device(params[0])
            conv = flow.nn.Conv2d(1, 1, (3, 3), bias=True).to(target)

            weight_is_zero = np.allclose(
                conv.weight.numpy(), np.zeros((1, 1, 3, 3)), rtol=1e-9, atol=1e-10
            )
            test_case.assertFalse(weight_is_zero)

            bias_is_zero = np.allclose(
                conv.bias.numpy(), np.zeros((1,)), rtol=1e-9, atol=1e-10
            )
            test_case.assertFalse(bias_is_zero)

    @autotest(n=3)
    def test_nn_functional_conv2d(test_case):
        """Functional conv2d with groups=3 on an all-ones image and kernel."""
        device = random_device()
        image = torch.ones((1, 3, 224, 224), requires_grad=True).to(device)
        filters = torch.ones((3, 1, 3, 3), requires_grad=True).to(device)
        return torch.nn.functional.conv2d(image, filters, groups=3)

    def test_conv2d(test_case):
        """Forward check of Conv2d(1, 3, (3, 3), bias=False) on each device."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                data=test_conv2d_data,
                weight=test_conv2d_weight,
                output=test_conv2d_output,
                device=device,
            )

    def test_conv2d_backward(test_case):
        """Gradient check of Conv2d(1, 3, (3, 3), bias=False) on each device."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d_backward(
                test_case,
                conv,
                data=test_conv2d_data,
                weight=test_conv2d_weight,
                data_grad=test_conv2d_data_grad,
                weight_grad=test_conv2d_weight_grad,
                device=device,
            )

    # bias grad not yet supported
    def test_conv2d_with_bias(test_case):
        """Forward check of Conv2d(1, 3, (3, 3), bias=True) with a fixed bias."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 3, (3, 3), bias=True)
            conv = conv.to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                data=test_conv2d_with_bias_data,
                weight=test_conv2d_with_bias_weight,
                output=test_conv2d_with_bias_output,
                bias=test_conv2d_with_bias_bias,
                device=device,
            )

    def test_conv2d_group(test_case):
        """Forward check of Conv2d(2, 2, (3, 3), groups=2, bias=False)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                data=test_conv2d_group_data,
                weight=test_conv2d_group_weight,
                output=test_conv2d_group_output,
                device=device,
            )

    def test_conv2d_group_backward(test_case):
        """Gradient check of Conv2d(2, 2, (3, 3), groups=2, bias=False)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(2, 2, (3, 3), groups=2, bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d_backward(
                test_case,
                conv,
                data=test_conv2d_group_data,
                weight=test_conv2d_group_weight,
                data_grad=test_conv2d_group_data_grad,
                weight_grad=test_conv2d_group_weight_grad,
                device=device,
            )

    def test_conv2d_padding(test_case):
        """Forward check of Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                data=test_conv2d_padding_data,
                weight=test_conv2d_padding_weight,
                output=test_conv2d_padding_output,
                device=device,
            )

    def test_conv2d_padding_backward(test_case):
        """Gradient check of Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 1, (3, 3), padding=(1, 2), bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d_backward(
                test_case,
                conv,
                data=test_conv2d_padding_data,
                weight=test_conv2d_padding_weight,
                data_grad=test_conv2d_padding_data_grad,
                weight_grad=test_conv2d_padding_weight_grad,
                device=device,
            )

    def test_conv2d_stride(test_case):
        """Forward check of a 3x3 Conv2d with padding=(1, 1) and stride=(2, 3)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(
                1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False
            )
            conv = conv.to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                data=test_conv2d_stride_data,
                weight=test_conv2d_stride_weight,
                output=test_conv2d_stride_output,
                device=device,
            )

    def test_conv2d_stride_backward(test_case):
        """Gradient check of a 3x3 Conv2d with padding=(1, 1) and stride=(2, 3)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(
                1, 1, (3, 3), padding=(1, 1), stride=(2, 3), bias=False
            )
            conv = conv.to(flow.device(device))
            _test_conv2d_backward(
                test_case,
                conv,
                data=test_conv2d_stride_data,
                weight=test_conv2d_stride_weight,
                data_grad=test_conv2d_stride_data_grad,
                weight_grad=test_conv2d_stride_weight_grad,
                device=device,
            )

    # NOTE(review): this guard skips the test when ONEFLOW_TEST_CPU_ONLY is set.
    # The body no longer places the module on CUDA unconditionally, so the guard
    # may be unnecessary now -- confirm on a CPU-only run before removing it.
    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_conv2d_kernel(test_case):
        """Forward check of a Conv2d with a non-square (3, 5) kernel.

        For every device in the grid a Conv2d(1, 1, (3, 5), bias=False) is built
        and handed to ``_test_conv2d`` together with the kernel fixtures.
        """
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            # The module is placed on ``device`` only. ``_test_conv2d`` installs
            # the reference weight and moves the module to ``device`` again, so
            # an extra hard-coded CUDA placement would only add a GPU round-trip
            # (and a CUDA requirement) to the CPU iteration.
            conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False).to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                test_conv2d_kernel_data,
                test_conv2d_kernel_weight,
                test_conv2d_kernel_output,
                device=device,
            )

    # NOTE(review): this guard skips the test when ONEFLOW_TEST_CPU_ONLY is set.
    # The body no longer places the module on CUDA unconditionally, so the guard
    # may be unnecessary now -- confirm on a CPU-only run before removing it.
    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_conv2d_kernel_backward(test_case):
        """Gradient check of a Conv2d with a non-square (3, 5) kernel.

        For every device in the grid a Conv2d(1, 1, (3, 5), bias=False) is built
        and handed to ``_test_conv2d_backward`` together with the kernel
        fixtures and their reference input/weight gradients.
        """
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            # The module is placed on ``device`` only. ``_test_conv2d_backward``
            # installs the reference weight and moves the module to ``device``
            # again, so an extra hard-coded CUDA placement would only add a GPU
            # round-trip (and a CUDA requirement) to the CPU iteration.
            conv = flow.nn.Conv2d(1, 1, (3, 5), bias=False).to(flow.device(device))
            _test_conv2d_backward(
                test_case,
                conv,
                test_conv2d_kernel_data,
                test_conv2d_kernel_weight,
                test_conv2d_kernel_data_grad,
                test_conv2d_kernel_weight_grad,
                device=device,
            )

    def test_conv2d_dilation(test_case):
        """Forward check of Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d(
                test_case,
                conv,
                data=test_conv2d_dilation_data,
                weight=test_conv2d_dilation_weight,
                output=test_conv2d_dilation_output,
                device=device,
            )

    def test_conv2d_dilation_backward(test_case):
        """Gradient check of Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False)."""
        device_grid = OrderedDict([("device", ["cuda", "cpu"])])
        for params in GenArgList(device_grid):
            device = params[0]
            conv = flow.nn.Conv2d(1, 1, (3, 3), dilation=(2, 3), bias=False)
            conv = conv.to(flow.device(device))
            _test_conv2d_backward(
                test_case,
                conv,
                data=test_conv2d_dilation_data,
                weight=test_conv2d_dilation_weight,
                data_grad=test_conv2d_dilation_data_grad,
                weight_grad=test_conv2d_dilation_weight_grad,
                device=device,
            )

    def test_large_in_channel_group_conv(test_case):
        """Run ``_test_conv2d_large_in_channel`` for every device in the grid."""
        grid = OrderedDict(
            [
                ("test_fun", [_test_conv2d_large_in_channel]),
                ("device", ["cuda", "cpu"]),
            ]
        )
        for params in GenArgList(grid):
            # First entry is the check function, the rest are its arguments.
            check, extra_args = params[0], params[1:]
            check(test_case, *extra_args)

    def test_large_out_channel_group_conv(test_case):
        """Run ``_test_conv2d_large_out_channel`` for every device in the grid."""
        grid = OrderedDict(
            [
                ("test_fun", [_test_conv2d_large_out_channel]),
                ("device", ["cuda", "cpu"]),
            ]
        )
        for params in GenArgList(grid):
            # First entry is the check function, the rest are its arguments.
            check, extra_args = params[0], params[1:]
            check(test_case, *extra_args)

    @autotest(n=5)
    def test_conv2d_with_random_data(test_case):
        """Conv2d built from generated hyper-parameters, applied to a 4-D tensor.

        The same generated channel count is used for the module's
        ``in_channels`` and for dimension 1 of the input tensor.
        """
        in_channels = random(1, 6)
        conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=random(1, 20),
            kernel_size=random(1, 4),
            stride=random() | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(1, 5) | nothing(),
            groups=random(1, 5) | nothing(),
            padding_mode=constant("zeros") | nothing(),
        )
        conv.train(random())
        device = random_device()
        conv.to(device)
        inputs = random_tensor(ndim=4, dim1=in_channels).to(device)
        return conv(inputs)

    @autotest(n=5, check_graph=False)
    def test_conv2d_0size_with_random_data(test_case):
        """Conv2d with generated hyper-parameters on an input whose dim 0 is 0.

        The same generated channel count is used for the module's
        ``in_channels`` and for dimension 1 of the input tensor.
        """
        in_channels = random(1, 6)
        conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=random(1, 20),
            kernel_size=random(1, 4),
            stride=random() | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(1, 5) | nothing(),
            groups=random(1, 5) | nothing(),
            padding_mode=constant("zeros") | nothing(),
        )
        conv.train(random())
        device = random_device()
        conv.to(device)
        inputs = random_tensor(ndim=4, dim0=0, dim1=in_channels).to(device)
        return conv(inputs)

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    @autotest(n=5, check_allclose=False)
    def test_conv2d_group_with_random_data(test_case):
        """Grouped Conv2d with 720 in/out channels and a generated group count.

        The ``pytorch`` side of both the module and the input is always placed
        on "cuda", regardless of the generated ``device``.
        """
        # 720 = 6!, which is divisible by every integer from 1 to 6
        # (their least common multiple is 60).
        # Presumably random(1, 7) yields group counts in that range -- TODO confirm.
        num_channels = 720
        conv = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_size=random(1, 4),
            stride=random() | nothing(),
            padding=random(1, 3).to(int) | nothing(),
            dilation=random(1, 5) | nothing(),
            groups=random(1, 7),
            padding_mode=constant("zeros") | nothing(),
        )
        conv.train(random())

        device = random_device()
        conv.to(device)
        conv.pytorch.to("cuda")
        inputs = random_tensor(ndim=4, dim1=num_channels).to(device)
        inputs.pytorch = inputs.pytorch.to("cuda")
        return conv(inputs)

    # @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    # def test_conv2d_NHWC_with_random_data(test_case):
    #     in_channels = np.random.randint(6, 33)
    #     out_channels = np.random.randint(32, 66)
    #     kernel_size = np.random.randint(1, 5)
    #     stride = np.random.randint(1, 2)
    #     padding = np.random.randint(1, 3)
    #     dilation = np.random.randint(1, 3)
    #     spatial = np.random.randint(6, 64)

    #     np_x = np.random.randn(4, in_channels, spatial, spatial).astype(np.float32)
    #     np_weight = np.random.randn(
    #         out_channels, in_channels, kernel_size, kernel_size
    #     ).astype(np.float32)
    #     np_bias = np.random.randn(out_channels).astype(np.float32)

    #     flow_nchw_input = flow.tensor(
    #         np_x, device="cuda", dtype=flow.float32, requires_grad=True
    #     )
    #     flow_nchw_weights = flow.nn.Parameter(
    #         flow.tensor(
    #             np_weight, device="cuda", dtype=flow.float32, requires_grad=True
    #         )
    #     )
    #     flow_nchw_bias = flow.nn.Parameter(
    #         flow.tensor(np_bias, device="cuda", dtype=flow.float32, requires_grad=True)
    #     )

    #     flow_nchw_conv = flow.nn.Conv2d(
    #         in_channels=in_channels,
    #         out_channels=out_channels,
    #         kernel_size=kernel_size,
    #         stride=stride,
    #         padding=padding,
    #         dilation=dilation,
    #     ).to("cuda")
    #     flow_nchw_conv.weight = flow_nchw_weights
    #     flow_nchw_conv.bias = flow_nchw_bias

    #     flow_nchw_out = flow_nchw_conv(flow_nchw_input)

    #     os.environ["ONEFLOW_ENABLE_NHWC"] = "1"
    #     flow_nhwc_input = flow.tensor(
    #         np_x, device="cuda", dtype=flow.float32, requires_grad=True
    #     )
    #     flow_nhwc_permuted_input = flow.permute(flow_nhwc_input, (0, 2, 3, 1))
    #     flow_nhwc_weights = flow.tensor(
    #         np_weight, device="cuda", dtype=flow.float32, requires_grad=True
    #     )
    #     flow_nhwc_permuted_weights = flow.nn.Parameter(
    #         flow.permute(flow_nhwc_weights, (0, 2, 3, 1))
    #     )
    #     flow_nhwc_bias = flow.nn.Parameter(
    #         flow.tensor(np_bias, device="cuda", dtype=flow.float32, requires_grad=True)
    #     )

    #     flow_nhwc_conv = flow.nn.Conv2d(
    #         in_channels=in_channels,
    #         out_channels=out_channels,
    #         kernel_size=kernel_size,
    #         stride=stride,
    #         padding=padding,
    #         dilation=dilation,
    #     ).to("cuda")
    #     flow_nhwc_conv.weight = flow_nhwc_permuted_weights
    #     flow_nhwc_conv.bias = flow_nhwc_bias

    #     flow_nhwc_out = flow_nhwc_conv(flow_nhwc_permuted_input)
    #     flow_nhwc_permuted_out = flow.permute(flow_nhwc_out, (0, 3, 1, 2))

    #     test_case.assertTrue(
    #         np.allclose(
    #             flow_nchw_out.numpy(),
    #             flow_nhwc_permuted_out.numpy(),
    #             rtol=1e-4,
    #             atol=1e-4,
    #         )
    #     )

    #     total_out = flow_nchw_out + flow_nhwc_permuted_out

    #     total_out = total_out.sum()
    #     total_out.backward()
    #     test_case.assertTrue(
    #         np.allclose(
    #             flow_nchw_weights.grad.numpy(),
    #             np.transpose(flow_nhwc_permuted_weights.grad.numpy(), (0, 3, 1, 2)),
    #             rtol=1e-4,
    #             atol=1e-4,
    #         )
    #     )
    #     test_case.assertTrue(
    #         np.allclose(
    #             flow_nchw_input.grad.numpy(),
    #             flow_nhwc_input.grad.numpy(),
    #             rtol=1e-4,
    #             atol=1e-4,
    #         )
    #     )
    #     os.environ["ONEFLOW_ENABLE_NHWC"] = "0"

    # @profile(torch.nn.functional.conv2d)
    # def profile_conv2d(test_case):
    #     input = torch.ones(8, 128, 28, 28)
    #     weight = torch.ones(128, 128, 3, 3)
    #     bias = torch.ones(128)
    #     torch.nn.functional.conv2d(input, weight, padding=1)
    #     torch.nn.functional.conv2d(input, weight, padding=1, stride=2)
    #     torch.nn.functional.conv2d(input, weight, bias=bias, padding=1)
    #     torch.nn.functional.conv2d(input, weight, bias=bias, padding=1, stride=2)


# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
