OpenDAS / Megatron-LM

Commit 531152d9, authored Mar 16, 2021 by Vijay Korthikanti

minor fixes

Parent: b1a83375
Changes: 2 files, +11 -10

megatron/fused_kernels/scaled_masked_softmax.h  (+3 -3)
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  (+8 -7)
megatron/fused_kernels/scaled_masked_softmax.h
...
@@ -32,11 +32,9 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
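These are explicit specializations of the vectorized copy helper declared in the hunk header, copy_vector(Datatype *dst, const Datatype *src): the <__half, 4> case collapses four 16-bit elements into a single 64-bit float2 load/store. As a hedged sketch of how such a helper is typically dispatched, the kernel below is illustrative only; example_copy_kernel, n, and the thread indexing are assumptions, not code from this commit.

// Illustrative sketch: each thread copies ELEMENTS_PER_LDG_STG contiguous
// elements through the specialized copy_vector<> defined above.
template <typename input_t, int ELEMENTS_PER_LDG_STG>
__global__ void example_copy_kernel(input_t *dst, const input_t *src, int n)
{
    int base = (blockIdx.x * blockDim.x + threadIdx.x) * ELEMENTS_PER_LDG_STG;
    if (base + ELEMENTS_PER_LDG_STG <= n) {
        // For input_t = __half and ELEMENTS_PER_LDG_STG = 4 this is one
        // float2 load and one float2 store per thread.
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(dst + base, src + base);
    }
}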
...
@@ -250,6 +248,8 @@ __global__ void scaled_masked_softmax_warp_backward(

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
...
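In the backward hunk above, temp_grad and temp_output stage one vectorized load of ELEMENTS_PER_LDG_STG input_t values before the elements are widened into the acc_t register tiles grad_reg and output_reg. The helper below is a hedged sketch of that stage-then-widen pattern; load_and_widen and its parameters are hypothetical names for illustration, not lines from this commit.

// Hypothetical sketch: stage ELEMENTS_PER_LDG_STG values through a temporary
// array via copy_vector, then widen each input_t element into acc_t registers.
template <typename input_t, typename acc_t, int ELEMENTS_PER_LDG_STG>
__device__ __inline__ void load_and_widen(acc_t *reg, const input_t *src)
{
    input_t temp[ELEMENTS_PER_LDG_STG];
    copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp, src);
    #pragma unroll
    for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        reg[element] = (acc_t)temp[element];   // e.g. __half -> float widening
    }
}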
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
...
@@ -32,21 +32,22 @@ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = 0; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
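copy_zero_vector zero-fills ELEMENTS_PER_LDG elements of dst; the <__half, 4> specialization shown here does it with a single 64-bit store of make_float2(0.0f, 0.0f). Purely as a hedged illustration of what that vectorized store replaces, copy_zero_vector_scalar below is a hypothetical helper, not part of this commit.

// Hypothetical scalar equivalent of copy_zero_vector<__half, 4>:
// four 16-bit writes instead of one 64-bit float2 write.
__device__ __inline__ void copy_zero_vector_scalar(__half *dst)
{
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        dst[i] = __float2half(0.0f);
    }
}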
int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
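log2_ceil returns the smallest exponent e such that (1 << e) >= value. A hedged usage sketch follows; softmax_elements, log2_elements, and next_power_of_two are illustrative names, since the surrounding dispatch code is not part of this hunk.

// Round the softmax row length up to a power of two before choosing the
// warp layout; e.g. log2_ceil(5) == 3, so next_power_of_two == 8.
int log2_elements = log2_ceil(softmax_elements);
int next_power_of_two = 1 << log2_elements;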
...
@@ -199,7 +200,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
            }
            copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
        } else if (element_index < element_count) {
            copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
        } else {
            break;
        }
...