import unittest
from collections import OrderedDict

import numpy as np
from oneflow.test_utils.test_util import GenArgList
from oneflow.test_utils.automated_test_util import *

import oneflow as flow
import oneflow.nn as nn
import oneflow.unittest

# --- Conv1d check (CUDA): forward output and input gradient ---------------
# A single-channel length-5 signal is convolved with three fixed 1x3 kernels
# (stride 1, no bias, no padding -> output length 3). The forward result and
# the gradient of output.sum() w.r.t. the input are compared with reference
# values; each check prints "... Passed" or a "... FAILED" line with the
# largest absolute error, so a mismatch is never silent.
np_arr = np.array([[[1.28795946, -0.2921792, 0.20338029, 0.78604293, -1.89607573]]])
input = flow.tensor(
    np_arr, dtype=flow.float32, device=flow.device("cuda"), requires_grad=True
)
# Kernel weights, shape (out_channels=3, in_channels=1, kernel_size=3).
weight = np.array(
    [
        [[0.10197904, 0.3372305, -0.25743008]],
        [[0.27720425, -0.52435774, -0.38381988]],
        [[0.56016803, -0.10063095, -0.10760903]],
    ]
)
m = nn.Conv1d(1, 3, 3, stride=1, bias=False)
m.weight = flow.nn.Parameter(flow.Tensor(weight))
m = m.to("cuda")
output = m(input)
# Expected forward output, shape (1, 3, 3).
np_out = np.array(
    [
        [
            [-0.01954307, -0.16356121, 0.77392507],
            [0.43217283, -0.48933625, 0.37196174],
            [0.72899038, -0.2687211, 0.23886177],
        ]
    ]
)
if np.allclose(output.numpy(), np_out, rtol=1e-06, atol=1e-06):
    print("conv1d Passed")
else:
    print(f"conv1d FAILED (max abs diff {np.abs(output.numpy() - np_out).max():g})")
# Backward through the sum of all outputs: each input position's gradient is
# the sum of the kernel taps that touched it, accumulated over all 3 channels.
output = output.sum()
output.backward()
np_grad = np.array(
    [[[0.93935132, 0.65159315, -0.09726584, -1.03661716, -0.74885899]]]
)
if np.allclose(input.grad.numpy(), np_grad, rtol=1e-06, atol=1e-06):
    print("conv1d_back Passed")
else:
    print(
        f"conv1d_back FAILED "
        f"(max abs diff {np.abs(input.grad.numpy() - np_grad).max():g})"
    )



# Weights for the bias-free Conv2d(1, 3, (3, 3)) check: shape (3, 1, 3, 3),
# i.e. one 3x3 kernel for each of the 3 output channels over 1 input channel.
test_conv2d_weight = np.array(
    [
        [
            [
                [0.8586049675941467, -0.2279418259859085, 0.2013147622346878],
                [0.35005471110343933, 0.5360521078109741, 1.5194443464279175],
                [1.9040879011154175, -1.5734431743621826, -0.14007866382598877],
            ]
        ],
        [
            [
                [0.29670074582099915, 1.3111951351165771, 0.5035904049873352],
                [-1.1894450187683105, -0.5502137541770935, -1.591875672340393],
                [-1.1081947088241577, 0.07872020453214645, -0.9185634255409241],
            ]
        ],
        [
            [
                [-0.7457143664360046, -1.2080862522125244, 1.8140212297439575],
                [-1.5227429866790771, -2.515244960784912, -1.3549325466156006],
                [-0.9574840068817139, -0.7248556613922119, 1.1119636297225952],
            ]
        ],
    ]
)
# Input for the bias-free Conv2d check: one sample, one channel, 5x5
# (shape (1, 1, 5, 5)); a 3x3 kernel without padding yields a 3x3 output.
test_conv2d_data = np.array(
    [
        [
            [
                [
                    1.1630785465240479,
                    0.4838046133518219,
                    0.299563467502594,
                    0.15302546322345734,
                    -1.168814778327942,
                ],
                [
                    1.5580710172653198,
                    -0.5459445714950562,
                    -2.3556296825408936,
                    0.5414402484893799,
                    2.678506374359131,
                ],
                [
                    1.2546343803405762,
                    -0.5487740635871887,
                    -0.6810643672943115,
                    -0.13531559705734253,
                    0.37723132967948914,
                ],
                [
                    0.41016456484794617,
                    0.5712682008743286,
                    -2.757962703704834,
                    1.0762799978256226,
                    -0.6141325235366821,
                ],
                [
                    1.830764889717102,
                    -1.1468064785003662,
                    0.053837940096855164,
                    -2.5074806213378906,
                    -0.5916498899459839,
                ],
            ]
        ]
    ]
)
# Expected x.grad after conv(x).sum().backward() for the bias-free Conv2d;
# same shape as test_conv2d_data, (1, 1, 5, 5).
test_conv2d_data_grad = np.array(
    [
        [
            [
                [
                    0.4095913469791412,
                    0.2847584038972855,
                    2.803684800863266,
                    2.3940934538841248,
                    2.5189263969659805,
                ],
                [
                    -1.9525419473648071,
                    -4.606781497597694,
                    -3.51521897315979,
                    -1.562677025794983,
                    1.0915625244379044,
                ],
                [
                    -2.1141327619552612,
                    -6.987950943410397,
                    -5.84306687861681,
                    -3.7289341166615486,
                    1.1448840647935867,
                ],
                [
                    -2.5237241089344025,
                    -7.272709347307682,
                    -8.646751679480076,
                    -6.123027570545673,
                    -1.3740423321723938,
                ],
                [
                    -0.1615908145904541,
                    -2.381169445812702,
                    -2.32784790545702,
                    -2.1662570908665657,
                    0.0533215403556824,
                ],
            ]
        ]
    ]
)
# Expected conv.weight.grad for the same backward pass, shape (3, 1, 3, 3).
# The loss is the plain sum of all outputs, so every output channel receives
# the same gradient (each tap's sum over the 3x3 grid of input windows) --
# hence the three identical entries.
test_conv2d_weight_grad = np.array(
    [
        [
            [
                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
            ]
        ],
        [
            [
                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
            ]
        ],
        [
            [
                [0.6277393400669098, -2.7888944894075394, -0.2910575419664383],
                [-3.095237225294113, -4.835702538490295, -1.8706469237804413],
                [-1.0139376372098923, -6.076017692685127, -5.780256435275078],
            ]
        ],
    ]
)
# Expected forward output of the bias-free Conv2d (test_conv2d_weight applied
# to test_conv2d_data, stride 1, no padding), shape (1, 3, 3, 3).
test_conv2d_output = np.array(
    [
        [
            [
                [0.9699610471725464, -0.20758534967899323, 2.3857712745666504],
                [0.3666309118270874, 4.690882682800293, -8.203354835510254],
                [2.6072847843170166, -1.9033538103103638, 2.331153154373169],
            ],
            [
                [2.519343852996826, 2.3757898807525635, -1.6613528728485107],
                [0.5777544379234314, -3.5739502906799316, 5.349126815795898],
                [0.729295015335083, 1.5791023969650269, 3.7627718448638916],
            ],
            [
                [-0.27685487270355225, 6.446267127990723, -2.762883424758911],
                [-8.25644588470459, 9.616064071655273, 8.005367279052734],
                [-0.6944921016693115, 3.866114854812622, 4.788446426391602],
            ],
        ]
    ]
)
# Weights for the Conv2d(1, 3, (3, 3), bias=True) forward check,
# shape (3, 1, 3, 3).
test_conv2d_with_bias_weight = np.array(
    [
        [
            [
                [1.8271433115005493, -1.0446699857711792, 1.0062190294265747],
                [0.5174201130867004, -0.806931734085083, 1.3769007921218872],
                [0.205885112285614, 0.9943519234657288, -0.23580588400363922],
            ]
        ],
        [
            [
                [0.29881811141967773, -1.9982075691223145, 0.3511354625225067],
                [-0.7644741535186768, 1.2594351768493652, -0.9629734754562378],
                [0.5080506205558777, 0.7561734318733215, 1.6839302778244019],
            ]
        ],
        [
            [
                [1.2573646306991577, 0.13123232126235962, 1.6403018236160278],
                [-1.2138012647628784, 2.399970531463623, -0.38509097695350647],
                [-0.9878040552139282, 0.9585888385772705, -1.4976465702056885],
            ]
        ],
    ]
)
# Bias for the same biased Conv2d: one value per output channel (3 values).
test_conv2d_with_bias_bias = np.array(
    [0.6605162620544434, -0.18903568387031555, -0.27302607893943787]
)
# Input for the biased Conv2d check: one sample, one channel, 5x5
# (shape (1, 1, 5, 5)).
test_conv2d_with_bias_data = np.array(
    [
        [
            [
                [
                    -0.47827261686325073,
                    -1.1739492416381836,
                    -0.7921845316886902,
                    0.9321041703224182,
                    -3.1557741165161133,
                ],
                [
                    2.1935296058654785,
                    -0.5385921001434326,
                    -0.8611332774162292,
                    -1.881519079208374,
                    -0.7205708026885986,
                ],
                [
                    -0.35601571202278137,
                    -0.15963983535766602,
                    1.797447681427002,
                    0.19594945013523102,
                    -1.7376397848129272,
                ],
                [
                    0.047347065061330795,
                    0.14580930769443512,
                    0.32604914903640747,
                    0.4578782916069031,
                    -0.8942581415176392,
                ],
                [
                    0.49383941292762756,
                    -0.9043426513671875,
                    -1.2140793800354004,
                    2.1564064025878906,
                    1.0938222408294678,
                ],
            ]
        ]
    ]
)
# Expected forward output of the biased Conv2d (weights, bias and input
# above; stride 1, no padding), shape (1, 3, 3, 3).
test_conv2d_with_bias_output = np.array(
    [
        [
            [
                [-0.05607491731643677, -0.185230553150177, -3.8808679580688477],
                [6.861937046051025, -2.3341472148895264, -0.5597308874130249],
                [1.8299254179000854, -2.770848274230957, 2.1958212852478027],
            ],
            [
                [2.9348952770233154, 4.117504119873047, -6.278541088104248],
                [0.2638452351093292, 3.998856782913208, 2.612290620803833],
                [-1.9891828298568726, -1.6476304531097412, 3.39066219329834],
            ],
            [
                [-8.44466781616211, 0.5747121572494507, -8.501373291015625],
                [-0.036642804741859436, -0.23458999395370483, -2.370849370956421],
                [2.8372013568878174, -2.987276077270508, 1.8382092714309692],
            ],
        ]
    ]
)

# --- Conv2d checks (CUDA) -------------------------------------------------
# Every Conv2d check below places its module and inputs on this device.
to_device = flow.device("cuda")


def _report(name, actual, expected, rtol=1e-4, atol=1e-8):
    """Print ``"<name> Passed"`` if ``actual`` matches ``expected``.

    On a mismatch a ``"<name> FAILED"`` line with the largest absolute
    difference is printed instead, so a failing check is never silent.

    Args:
        name: Label used in the printed result line.
        actual: Result produced by OneFlow, already converted to numpy.
        expected: Reference numpy array.
        rtol: Relative tolerance forwarded to ``np.allclose``.
        atol: Absolute tolerance forwarded to ``np.allclose``.
    """
    if np.allclose(actual, expected, rtol=rtol, atol=atol):
        print(f"{name} Passed")
    else:
        # np.allclose returned (rather than raised), so the shapes broadcast
        # and the element-wise difference below is well defined.
        max_diff = np.abs(actual - expected).max()
        print(f"{name} FAILED (max abs diff {max_diff:g})")


def _check_conv2d_with_bias():
    """Compare the forward output of a biased 1->3 channel 3x3 Conv2d."""
    conv = flow.nn.Conv2d(1, 3, (3, 3), bias=True).to(to_device)
    x = flow.tensor(test_conv2d_with_bias_data, dtype=flow.float32, device=to_device)
    conv.weight = flow.nn.Parameter(flow.Tensor(test_conv2d_with_bias_weight))
    conv.bias = flow.nn.Parameter(flow.Tensor(test_conv2d_with_bias_bias))
    # The replacement Parameters are created without a device argument, so
    # move the module again to place them on ``to_device`` alongside ``x``.
    conv.to(to_device)
    _report("conv2d_bias", conv(x).numpy(), test_conv2d_with_bias_output)


def _check_conv2d_without_bias():
    """Check forward output and input/weight gradients of a bias-free Conv2d.

    The loss is the sum of all outputs; the expected gradients are
    ``test_conv2d_data_grad`` (input) and ``test_conv2d_weight_grad`` (weight).
    """
    conv = flow.nn.Conv2d(1, 3, (3, 3), bias=False).to(to_device)
    x = flow.tensor(
        test_conv2d_data, dtype=flow.float32, device=to_device, requires_grad=True
    )
    conv.weight = flow.nn.Parameter(flow.Tensor(test_conv2d_weight), requires_grad=True)
    conv.to(to_device)
    of_out = conv(x)
    # Also verify the forward pass against its reference output.
    _report("conv2d_forward", of_out.numpy(), test_conv2d_output)
    of_out.sum().backward()
    _report("conv2d_back_data_grad", x.grad.numpy(), test_conv2d_data_grad)
    _report(
        "conv2d_back_weight_grad", conv.weight.grad.numpy(), test_conv2d_weight_grad
    )


# Each check is run twice back to back; presumably this confirms that a second
# run (with whatever state the first run left behind) still matches the
# references -- TODO confirm whether the repetition is intended.
for _ in range(2):
    _check_conv2d_with_bias()
    _check_conv2d_without_bias()






