Unverified Commit db92ee13 authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Merge pull request #64 from ROCmSoftwarePlatform/IFU-master-2021-12-08

IFU-master-2021-12-08
parents d150afdc 68364b49
#pragma once #pragma once
//Philox CUDA. // Philox CUDA.
class Philox { class Philox {
public: public:
...@@ -15,28 +15,30 @@ public: ...@@ -15,28 +15,30 @@ public:
incr_n(offset / 4); incr_n(offset / 4);
} }
__device__ inline uint4 operator()() { __device__ inline uint4 operator()() {
if(STATE == 0) { if (STATE == 0) {
uint4 counter_ = counter; uint4 counter_ = counter;
uint2 key_ = key; uint2 key_ = key;
//7-round philox // 7-round philox
for(int i = 0; i < 6; i++) { for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_); counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A); key_.y += (kPhilox10B); key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
} }
output = single_round(counter_, key_); output = single_round(counter_, key_);
incr(); incr();
} }
//return a float4 directly // return a float4 directly
//unsigned long ret; // unsigned long ret;
//switch(STATE) { // switch(STATE) {
// case 0: ret = output.x; break; // case 0: ret = output.x; break;
// case 1: ret = output.y; break; // case 1: ret = output.y; break;
// case 2: ret = output.z; break; // case 2: ret = output.z; break;
// case 3: ret = output.w; break; // case 3: ret = output.w; break;
//} //}
//STATE = (STATE + 1) % 4; // STATE = (STATE + 1) % 4;
return output; return output;
} }
private: private:
uint4 counter; uint4 counter;
uint4 output; uint4 output;
...@@ -67,7 +69,7 @@ private: ...@@ -67,7 +69,7 @@ private:
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b, __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) { unsigned int *result_high) {
*result_high = __umulhi(a, b); *result_high = __umulhi(a, b);
return a*b; return a * b;
} }
__device__ inline uint4 single_round(uint4 ctr, uint2 key) { __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
unsigned int hi0; unsigned int hi0;
...@@ -84,7 +86,7 @@ private: ...@@ -84,7 +86,7 @@ private:
}; };
// Inverse of 2^32. // Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f #define M_RAN_INVM32 2.3283064e-10f
__device__ __inline__ float4 uniform4(uint4 x) { __device__ __inline__ float4 uniform4(uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,x.w * M_RAN_INVM32); return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
x.w * M_RAN_INVM32);
} }
This diff is collapsed.
This diff is collapsed.
...@@ -2,4 +2,5 @@ from .fused_sgd import FusedSGD ...@@ -2,4 +2,5 @@ from .fused_sgd import FusedSGD
from .fused_adam import FusedAdam from .fused_adam import FusedAdam
from .fused_novograd import FusedNovoGrad from .fused_novograd import FusedNovoGrad
from .fused_lamb import FusedLAMB from .fused_lamb import FusedLAMB
from .fused_adagrad import FusedAdagrad from .fused_adagrad import FusedAdagrad
\ No newline at end of file from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler
from apex.transformer._data._batchsampler import MegatronPretrainingSampler
__all__ = [
"MegatronPretrainingRandomSampler",
"MegatronPretrainingSampler",
]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment